From c846a2206894158806a196d37198016ae0293cf3 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Sun, 28 Sep 2025 20:04:32 +0100 Subject: [PATCH 01/35] Get aggregates working! --- docs/quickstart.ipynb | 65 ++++-- src/policyengine/database/__init__.py | 9 + src/policyengine/database/aggregate.py | 74 ++++++- .../baseline_parameter_value_table.py | 106 ++++++--- .../database/baseline_variable_table.py | 61 ++++-- src/policyengine/database/database.py | 64 +++++- src/policyengine/database/dataset_table.py | 75 ++++++- src/policyengine/database/dynamic_table.py | 50 ++++- src/policyengine/database/link.py | 78 +------ src/policyengine/database/model_table.py | 45 +++- .../database/model_version_table.py | 55 ++++- src/policyengine/database/parameter_table.py | 79 ++++++- .../database/parameter_value_table.py | 106 ++++++--- src/policyengine/database/policy_table.py | 118 +++++++++- .../database/report_element_table.py | 49 +++++ src/policyengine/database/report_table.py | 96 +++++++++ src/policyengine/database/simulation_table.py | 203 ++++++++++++++++-- src/policyengine/database/table_mixin.py | 80 +++++++ src/policyengine/database/user_table.py | 29 +++ .../database/versioned_dataset_table.py | 25 ++- src/policyengine/models/__init__.py | 6 + src/policyengine/models/aggregate.py | 9 +- .../models/baseline_parameter_value.py | 4 +- src/policyengine/models/model.py | 2 + src/policyengine/models/parameter.py | 2 + src/policyengine/models/policyengine_uk.py | 3 +- src/policyengine/models/policyengine_us.py | 2 +- src/policyengine/models/report.py | 10 + src/policyengine/models/simulation.py | 1 + 29 files changed, 1255 insertions(+), 251 deletions(-) create mode 100644 src/policyengine/database/table_mixin.py diff --git a/docs/quickstart.ipynb b/docs/quickstart.ipynb index 9340cb1a..3932cd2a 100644 --- a/docs/quickstart.ipynb +++ b/docs/quickstart.ipynb @@ -66,13 +66,13 @@ "£500,000" ], "y": [ - 6628102.860910795, - 10308039.540624166, - 7153251.306053954, - 4288185.176098487, - 1690702.647548969, - 1320125.7573599513, - 326073.73102501093, + 6601567.352745674, + 10308068.997645264, + 7214847.8757442, + 4260541.368792315, + 1683256.5811573816, + 1320122.584893554, + 326076.2586429423, 187608.23132836912, 63106.63353048405, 41838.373842805624 @@ -1110,13 +1110,13 @@ "£500,000" ], "y": [ - 6628102.860910795, - 10308039.540624166, - 7153251.306053954, - 4288185.176098487, - 1690702.647548969, - 1320125.7573599513, - 326073.73102501093, + 6601567.352745674, + 10308068.997645264, + 7214847.8757442, + 4260541.368792315, + 1683256.5811573816, + 1320122.584893554, + 326076.2586429423, 187608.23132836912, 63106.63353048405, 41838.373842805624 @@ -1148,12 +1148,12 @@ "£500,000" ], "y": [ - 6172777.805479924, - 10310058.00384126, - 6911190.799784593, - 4471614.799692215, - 2005466.130918176, - 1471720.3202646417, + 6146242.297314803, + 10316900.228889126, + 6968299.661946169, + 4446106.99205335, + 1992976.835352243, + 1472301.8444251162, 341808.24952948757, 218180.35939976107, 63106.63353048405, @@ -2169,6 +2169,25 @@ { "cell_type": "code", "execution_count": 3, + "id": "ac6c443e", + "metadata": {}, + "outputs": [], + "source": [ + "from policyengine.database import Database\n", + "from policyengine.models.policyengine_uk import policyengine_uk_latest_version\n", + "\n", + "database = Database(\"postgresql://postgres:postgres@127.0.0.1:54322/postgres\")\n", + "\n", + "# These two lines are not usually needed, but you should use them the first time you set up a new database\n", + "database.reset() # Drop 
and recreate all tables\n", + "database.register_model_version(\n", + " policyengine_uk_latest_version\n", + ") # Add in the model, model version, parameters and baseline parameter values and variables.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, "id": "f14c85eb", "metadata": {}, "outputs": [], @@ -2198,17 +2217,17 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "id": "2041dfeb", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "Policy(id='26f30afa-77b9-4435-812c-071873e25400', name='Increase personal allowance to £20,000', description='A policy to increase the personal allowance for income tax to £20,000.', parameter_values=[], simulation_modifier=None, created_at=datetime.datetime(2025, 9, 20, 12, 36, 27, 162725), updated_at=datetime.datetime(2025, 9, 20, 12, 36, 27, 162729))" + "Policy(id='64792d73-b52a-431b-a22a-1b5cf3ae4551', name='Increase personal allowance to £20,000', description='A policy to increase the personal allowance for income tax to £20,000.', parameter_values=[ParameterValue(id='5ffda0b7-c70f-4cef-b45b-91fb663d12db', parameter=Parameter(id='gov.hmrc.income_tax.allowances.personal_allowance.amount', description=None, data_type=None, model=Model(id='policyengine_uk', name='PolicyEngine UK', description=\"PolicyEngine's open-source tax-benefit microsimulation model.\", simulation_function=), label=None, unit=None), value=20000, start_date=datetime.datetime(2029, 1, 1, 0, 0), end_date=None)], simulation_modifier=None, created_at=datetime.datetime(2025, 9, 28, 19, 57, 41, 515173), updated_at=datetime.datetime(2025, 9, 28, 19, 57, 41, 515176))" ] }, - "execution_count": 4, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } diff --git a/src/policyengine/database/__init__.py b/src/policyengine/database/__init__.py index 2490d15c..5efcae3a 100644 --- a/src/policyengine/database/__init__.py +++ b/src/policyengine/database/__init__.py @@ -25,6 +25,9 @@ VersionedDatasetTable, versioned_dataset_table_link, ) +from .report_table import ReportTable, report_table_link +from .report_element_table import ReportElementTable, report_element_table_link +from .aggregate import AggregateTable, aggregate_table_link __all__ = [ "Database", @@ -41,6 +44,9 @@ "BaselineParameterValueTable", "BaselineVariableTable", "SimulationTable", + "ReportTable", + "ReportElementTable", + "AggregateTable", # Links "model_table_link", "model_version_table_link", @@ -53,4 +59,7 @@ "baseline_parameter_value_table_link", "baseline_variable_table_link", "simulation_table_link", + "report_table_link", + "report_element_table_link", + "aggregate_table_link", ] diff --git a/src/policyengine/database/aggregate.py b/src/policyengine/database/aggregate.py index b945ea6c..44c8aacd 100644 --- a/src/policyengine/database/aggregate.py +++ b/src/policyengine/database/aggregate.py @@ -1,9 +1,14 @@ +from typing import TYPE_CHECKING from uuid import uuid4 from sqlmodel import Field, SQLModel from policyengine.database.link import TableLink from policyengine.models.aggregate import Aggregate +from policyengine.models import Simulation + +if TYPE_CHECKING: + from .database import Database class AggregateTable(SQLModel, table=True): @@ -21,13 +26,76 @@ class AggregateTable(SQLModel, table=True): filter_variable_leq: float | None = None filter_variable_geq: float | None = None aggregate_function: str + reportelement_id: str | None = None value: float | None = None + @classmethod + def convert_from_model(cls, model: Aggregate, database: "Database" = None) 
-> "AggregateTable": + """Convert an Aggregate instance to an AggregateTable instance. + + Args: + model: The Aggregate instance to convert + database: The database instance for persisting the simulation if needed + + Returns: + An AggregateTable instance + """ + # Don't try to save the simulation here - it's already being saved + # This prevents circular references + + return cls( + id=model.id, + simulation_id=model.simulation.id if model.simulation else None, + entity=model.entity, + variable_name=model.variable_name, + year=model.year, + filter_variable_name=model.filter_variable_name, + filter_variable_value=model.filter_variable_value, + filter_variable_leq=model.filter_variable_leq, + filter_variable_geq=model.filter_variable_geq, + aggregate_function=model.aggregate_function, + reportelement_id=model.reportelement_id, + value=model.value, + ) + + def convert_to_model(self, database: "Database" = None) -> Aggregate: + """Convert this AggregateTable instance to an Aggregate instance. + + Args: + database: The database instance for resolving the simulation foreign key + + Returns: + An Aggregate instance + """ + from .simulation_table import SimulationTable + from sqlmodel import select + + # Resolve the simulation foreign key + simulation = None + if database and self.simulation_id: + sim_table = database.session.exec( + select(SimulationTable).where(SimulationTable.id == self.simulation_id) + ).first() + if sim_table: + simulation = sim_table.convert_to_model(database) + + return Aggregate( + id=self.id, + simulation=simulation, + entity=self.entity, + variable_name=self.variable_name, + year=self.year, + filter_variable_name=self.filter_variable_name, + filter_variable_value=self.filter_variable_value, + filter_variable_leq=self.filter_variable_leq, + filter_variable_geq=self.filter_variable_geq, + aggregate_function=self.aggregate_function, + reportelement_id=self.reportelement_id, + value=self.value, + ) + aggregate_table_link = TableLink( model_cls=Aggregate, table_cls=AggregateTable, - model_to_table_custom_transforms=dict( - simulation_id=lambda a: a.simulation.id, - ), ) diff --git a/src/policyengine/database/baseline_parameter_value_table.py b/src/policyengine/database/baseline_parameter_value_table.py index 49282996..6485223c 100644 --- a/src/policyengine/database/baseline_parameter_value_table.py +++ b/src/policyengine/database/baseline_parameter_value_table.py @@ -3,11 +3,15 @@ from uuid import uuid4 from sqlmodel import JSON, Column, Field, SQLModel +from typing import TYPE_CHECKING -from policyengine.models import BaselineParameterValue +from policyengine.models import ModelVersion, Parameter, BaselineParameterValue from .link import TableLink +if TYPE_CHECKING: + from .database import Database + class BaselineParameterValueTable(SQLModel, table=True): __tablename__ = "baseline_parameter_values" @@ -25,42 +29,84 @@ class BaselineParameterValueTable(SQLModel, table=True): start_date: datetime = Field(nullable=False) end_date: datetime | None = Field(default=None) + @classmethod + def convert_from_model(cls, model: BaselineParameterValue, database: "Database" = None) -> "BaselineParameterValueTable": + """Convert a BaselineParameterValue instance to a BaselineParameterValueTable instance.""" + import math + + # Ensure foreign objects are persisted if database is provided + if database: + if model.parameter: + database.set(model.parameter, commit=False) + if model.model_version: + database.set(model.model_version, commit=False) + + # Handle special float values + value = 
model.value + if isinstance(value, float): + if math.isinf(value): + value = "Infinity" if value > 0 else "-Infinity" + elif math.isnan(value): + value = "NaN" + + return cls( + id=model.id, + parameter_id=model.parameter.id if model.parameter else None, + model_id=model.parameter.model.id if model.parameter and model.parameter.model else None, + model_version_id=model.model_version.id if model.model_version else None, + value=value, + start_date=model.start_date, + end_date=model.end_date, + ) + + def convert_to_model(self, database: "Database" = None) -> BaselineParameterValue: + """Convert this BaselineParameterValueTable instance to a BaselineParameterValue instance.""" + from .parameter_table import ParameterTable + from .model_version_table import ModelVersionTable + from sqlmodel import select + + # Resolve foreign keys + parameter = None + model_version = None -def transform_value_to_table(bpv): - """Transform value for storage, handling special float values.""" - import math + if database: + if self.parameter_id and self.model_id: + param_table = database.session.exec( + select(ParameterTable).where( + ParameterTable.id == self.parameter_id, + ParameterTable.model_id == self.model_id + ) + ).first() + if param_table: + parameter = param_table.convert_to_model(database) - value = bpv.value - if isinstance(value, float): - if math.isinf(value): - return "Infinity" if value > 0 else "-Infinity" - elif math.isnan(value): - return "NaN" - return value + if self.model_version_id: + version_table = database.session.exec( + select(ModelVersionTable).where(ModelVersionTable.id == self.model_version_id) + ).first() + if version_table: + model_version = version_table.convert_to_model(database) + # Handle special string values + value = self.value + if value == "Infinity": + value = float("inf") + elif value == "-Infinity": + value = float("-inf") + elif value == "NaN": + value = float("nan") -def transform_value_from_table(table_row): - """Transform value from storage, converting special strings back to floats.""" - value = table_row.value - if value == "Infinity": - return float("inf") - elif value == "-Infinity": - return float("-inf") - elif value == "NaN": - return float("nan") - return value + return BaselineParameterValue( + id=self.id, + parameter=parameter, + model_version=model_version, + value=value, + start_date=self.start_date, + end_date=self.end_date, + ) baseline_parameter_value_table_link = TableLink( model_cls=BaselineParameterValue, table_cls=BaselineParameterValueTable, - model_to_table_custom_transforms=dict( - parameter_id=lambda bpv: bpv.parameter.id, - model_id=lambda bpv: bpv.parameter.model.id, # Add model_id from parameter - model_version_id=lambda bpv: bpv.model_version.id, - value=transform_value_to_table, - ), - table_to_model_custom_transforms=dict( - value=transform_value_from_table, - ), ) diff --git a/src/policyengine/database/baseline_variable_table.py b/src/policyengine/database/baseline_variable_table.py index 6f7e61e5..6d836ee2 100644 --- a/src/policyengine/database/baseline_variable_table.py +++ b/src/policyengine/database/baseline_variable_table.py @@ -1,10 +1,14 @@ from sqlmodel import Field, SQLModel +from typing import TYPE_CHECKING -from policyengine.models import BaselineVariable +from policyengine.models import ModelVersion, BaselineVariable from policyengine.utils.compress import compress_data, decompress_data from .link import TableLink +if TYPE_CHECKING: + from .database import Database + class BaselineVariableTable(SQLModel, table=True): 
__tablename__ = "baseline_variables" @@ -22,19 +26,52 @@ class BaselineVariableTable(SQLModel, table=True): description: str | None = Field(default=None) data_type: bytes | None = Field(default=None) # Pickled type + @classmethod + def convert_from_model(cls, model: BaselineVariable, database: "Database" = None) -> "BaselineVariableTable": + """Convert a BaselineVariable instance to a BaselineVariableTable instance.""" + from policyengine.utils.compress import compress_data + + # Ensure foreign objects are persisted if database is provided + if database and model.model_version: + database.set(model.model_version, commit=False) + + return cls( + id=model.id, + model_id=model.model_version.model.id if model.model_version and model.model_version.model else None, + model_version_id=model.model_version.id if model.model_version else None, + entity=model.entity, + label=model.label, + description=model.description, + data_type=compress_data(model.data_type) if model.data_type else None, + ) + + def convert_to_model(self, database: "Database" = None) -> BaselineVariable: + """Convert this BaselineVariableTable instance to a BaselineVariable instance.""" + from policyengine.utils.compress import decompress_data + from .model_version_table import ModelVersionTable + from sqlmodel import select + + # Resolve foreign keys + model_version = None + + if database and self.model_version_id: + version_table = database.session.exec( + select(ModelVersionTable).where(ModelVersionTable.id == self.model_version_id) + ).first() + if version_table: + model_version = version_table.convert_to_model(database) + + return BaselineVariable( + id=self.id, + model_version=model_version, + entity=self.entity, + label=self.label, + description=self.description, + data_type=decompress_data(self.data_type) if self.data_type else None, + ) + baseline_variable_table_link = TableLink( model_cls=BaselineVariable, table_cls=BaselineVariableTable, - primary_key=("id", "model_id"), # Composite primary key - model_to_table_custom_transforms=dict( - model_id=lambda bv: bv.model_version.model.id, # Add model_id from model_version - model_version_id=lambda bv: bv.model_version.id, - data_type=lambda bv: compress_data(bv.data_type) - if bv.data_type - else None, - ), - table_to_model_custom_transforms=dict( - data_type=lambda dt: decompress_data(dt) if dt else None, - ), ) diff --git a/src/policyengine/database/database.py b/src/policyengine/database/database.py index 61b03b83..2ae77e1c 100644 --- a/src/policyengine/database/database.py +++ b/src/policyengine/database/database.py @@ -88,6 +88,10 @@ def register_table(self, link: TableLink): link.table_cls.metadata.create_all(self.engine) def get(self, model_cls: type, **kwargs): + """Get a model instance from the database by its attributes.""" + from sqlmodel import select + + # Find the table class for this model table_link = next( ( link @@ -96,10 +100,26 @@ def get(self, model_cls: type, **kwargs): ), None, ) - if table_link is not None: - return table_link.get(self, **kwargs) + + if table_link is None: + return None + + # Query the database + statement = select(table_link.table_cls).filter_by(**kwargs) + result = self.session.exec(statement).first() + + if result is None: + return None + + # Use the table's convert_to_model method + return result.convert_to_model(self) def set(self, object: Any, commit: bool = True): + """Save or update a model instance in the database.""" + from sqlmodel import select + from sqlalchemy.inspection import inspect + + # Find the table class for this 
model table_link = next( ( link @@ -108,8 +128,36 @@ def set(self, object: Any, commit: bool = True): ), None, ) - if table_link is not None: - table_link.set(self, object, commit=commit) + + if table_link is None: + return + + # Convert model to table instance + table_obj = table_link.table_cls.convert_from_model(object, self) + + # Get primary key columns + mapper = inspect(table_link.table_cls) + pk_cols = [col.name for col in mapper.primary_key] + + # Build query to check if exists + query = select(table_link.table_cls) + for pk_col in pk_cols: + query = query.where( + getattr(table_link.table_cls, pk_col) == getattr(table_obj, pk_col) + ) + + existing = self.session.exec(query).first() + + if existing: + # Update existing record + for key, value in table_obj.model_dump().items(): + setattr(existing, key, value) + self.session.add(existing) + else: + self.session.add(table_obj) + + if commit: + self.session.commit() def register_model_version(self, model_version): """Register a model version with its model and seed objects. @@ -134,9 +182,9 @@ def register_model_version(self, model_version): id=model_version.model.id, name=model_version.model.name, description=model_version.model.description, - simulation_function=( - lambda m: compress_data(m.simulation_function) - )(model_version.model), + simulation_function=compress_data( + model_version.model.simulation_function + ), ) self.session.add(model_table) self.session.flush() @@ -194,6 +242,8 @@ def register_model_version(self, model_version): data_type=parameter.data_type.__name__ if parameter.data_type else None, + label=parameter.label, + unit=parameter.unit, ) self.session.add(param_table) diff --git a/src/policyengine/database/dataset_table.py b/src/policyengine/database/dataset_table.py index 4eb4156c..cf22cda8 100644 --- a/src/policyengine/database/dataset_table.py +++ b/src/policyengine/database/dataset_table.py @@ -1,12 +1,16 @@ +from typing import TYPE_CHECKING from uuid import uuid4 from sqlmodel import Field, SQLModel -from policyengine.models import Dataset +from policyengine.models import Dataset, Model, VersionedDataset from policyengine.utils.compress import compress_data, decompress_data from .link import TableLink +if TYPE_CHECKING: + from .database import Database + class DatasetTable(SQLModel, table=True): __tablename__ = "datasets" @@ -24,18 +28,67 @@ class DatasetTable(SQLModel, table=True): default=None, foreign_key="models.id", ondelete="SET NULL" ) + @classmethod + def convert_from_model(cls, model: Dataset, database: "Database" = None) -> "DatasetTable": + """Convert a Dataset instance to a DatasetTable instance. + + Args: + model: The Dataset instance to convert + database: The database instance for persisting foreign objects if needed + + Returns: + A DatasetTable instance + """ + # Ensure foreign objects are persisted if database is provided + if database: + if model.versioned_dataset: + database.set(model.versioned_dataset, commit=False) + if model.model: + database.set(model.model, commit=False) + + return cls( + id=model.id, + name=model.name, + description=model.description, + version=model.version, + versioned_dataset_id=model.versioned_dataset.id if model.versioned_dataset else None, + year=model.year, + data=compress_data(model.data) if model.data else None, + model_id=model.model.id if model.model else None, + ) + + def convert_to_model(self, database: "Database" = None) -> Dataset: + """Convert this DatasetTable instance to a Dataset instance. 
+ + Args: + database: The database instance for resolving foreign keys + + Returns: + A Dataset instance + """ + # Resolve foreign keys + versioned_dataset = None + model = None + + if database: + if self.versioned_dataset_id: + versioned_dataset = database.get(VersionedDataset, id=self.versioned_dataset_id) + if self.model_id: + model = database.get(Model, id=self.model_id) + + return Dataset( + id=self.id, + name=self.name, + description=self.description, + version=self.version, + versioned_dataset=versioned_dataset, + year=self.year, + data=decompress_data(self.data) if self.data else None, + model=model, + ) + dataset_table_link = TableLink( model_cls=Dataset, table_cls=DatasetTable, - model_to_table_custom_transforms=dict( - versioned_dataset_id=lambda d: d.versioned_dataset.id - if d.versioned_dataset - else None, - model_id=lambda d: d.model.id if d.model else None, - data=lambda d: compress_data(d.data) if d.data else None, - ), - table_to_model_custom_transforms=dict( - data=lambda b: decompress_data(b) if b else None, - ), ) diff --git a/src/policyengine/database/dynamic_table.py b/src/policyengine/database/dynamic_table.py index 6d510afe..086e6bd9 100644 --- a/src/policyengine/database/dynamic_table.py +++ b/src/policyengine/database/dynamic_table.py @@ -1,4 +1,5 @@ from datetime import datetime +from typing import TYPE_CHECKING from uuid import uuid4 from sqlmodel import Field, SQLModel @@ -8,6 +9,9 @@ from .link import TableLink +if TYPE_CHECKING: + from .database import Database + class DynamicTable(SQLModel, table=True): __tablename__ = "dynamics" @@ -19,16 +23,46 @@ class DynamicTable(SQLModel, table=True): created_at: datetime = Field(default_factory=datetime.now) updated_at: datetime = Field(default_factory=datetime.now) + @classmethod + def convert_from_model(cls, model: Dynamic, database: "Database" = None) -> "DynamicTable": + """Convert a Dynamic instance to a DynamicTable instance. + + Args: + model: The Dynamic instance to convert + database: The database instance (not used for this table) + + Returns: + A DynamicTable instance + """ + return cls( + id=model.id, + name=model.name, + description=model.description, + simulation_modifier=compress_data(model.simulation_modifier) if model.simulation_modifier else None, + created_at=model.created_at, + updated_at=model.updated_at, + ) + + def convert_to_model(self, database: "Database" = None) -> Dynamic: + """Convert this DynamicTable instance to a Dynamic instance. 
+ + Args: + database: The database instance (not used for this table) + + Returns: + A Dynamic instance + """ + return Dynamic( + id=self.id, + name=self.name, + description=self.description, + simulation_modifier=decompress_data(self.simulation_modifier) if self.simulation_modifier else None, + created_at=self.created_at, + updated_at=self.updated_at, + ) + dynamic_table_link = TableLink( model_cls=Dynamic, table_cls=DynamicTable, - model_to_table_custom_transforms=dict( - simulation_modifier=lambda d: compress_data(d.simulation_modifier) - if d.simulation_modifier - else None, - ), - table_to_model_custom_transforms=dict( - simulation_modifier=lambda b: decompress_data(b) if b else None, - ), ) diff --git a/src/policyengine/database/link.py b/src/policyengine/database/link.py index f6f19da8..2bb1a041 100644 --- a/src/policyengine/database/link.py +++ b/src/policyengine/database/link.py @@ -1,82 +1,8 @@ -from collections.abc import Callable -from typing import TYPE_CHECKING - from pydantic import BaseModel -from sqlmodel import SQLModel, select - -if TYPE_CHECKING: - from .database import Database +from sqlmodel import SQLModel class TableLink(BaseModel): + """Simple registry mapping model classes to table classes.""" model_cls: type[BaseModel] table_cls: type[SQLModel] - model_to_table_custom_transforms: dict[str, Callable] | None = None - table_to_model_custom_transforms: dict[str, Callable] | None = None - primary_key: str | tuple[str, ...] = ( - "id" # Allow multiple strings in tuple - ) - - def get(self, database: "Database", **kwargs): - statement = select(self.table_cls).filter_by(**kwargs) - result = database.session.exec(statement).first() - if result is None: - return None - model_data = result.model_dump() - if self.table_to_model_custom_transforms: - for ( - field, - transform, - ) in self.table_to_model_custom_transforms.items(): - model_data[field] = transform(getattr(result, field)) - - # Only include fields that exist in the model class - valid_fields = { - field_name for field_name in self.model_cls.__annotations__.keys() - } - filtered_model_data = { - k: v for k, v in model_data.items() if k in valid_fields - } - return self.model_cls(**filtered_model_data) - - def set(self, database: "Database", obj: BaseModel, commit: bool = True): - model_data = obj.model_dump() - if self.model_to_table_custom_transforms: - for ( - field, - transform, - ) in self.model_to_table_custom_transforms.items(): - model_data[field] = transform(obj) - # Only include fields that exist in the table class - valid_fields = { - field_name for field_name in self.table_cls.__annotations__.keys() - } - filtered_model_data = { - k: v for k, v in model_data.items() if k in valid_fields - } - table_obj = self.table_cls(**filtered_model_data) - - # Check if already exists using primary key - query = select(self.table_cls) - if isinstance(self.primary_key, tuple): - for key in self.primary_key: - query = query.where( - getattr(self.table_cls, key) == getattr(table_obj, key) - ) - else: - query = query.where( - getattr(self.table_cls, self.primary_key) - == getattr(table_obj, self.primary_key) - ) - - existing = database.session.exec(query).first() - if existing: - # Update existing record - for key, value in filtered_model_data.items(): - setattr(existing, key, value) - database.session.add(existing) - else: - database.session.add(table_obj) - - if commit: - database.session.commit() diff --git a/src/policyengine/database/model_table.py b/src/policyengine/database/model_table.py index 40d4d2e8..220238c8 
100644 --- a/src/policyengine/database/model_table.py +++ b/src/policyengine/database/model_table.py @@ -1,3 +1,5 @@ +from typing import TYPE_CHECKING + from sqlmodel import Field, SQLModel from policyengine.models import Model @@ -5,6 +7,9 @@ from .link import TableLink +if TYPE_CHECKING: + from .database import Database + class ModelTable(SQLModel, table=True, extend_existing=True): __tablename__ = "models" @@ -14,14 +19,42 @@ class ModelTable(SQLModel, table=True, extend_existing=True): description: str | None = Field(default=None) simulation_function: bytes + @classmethod + def convert_from_model(cls, model: Model, database: "Database" = None) -> "ModelTable": + """Convert a Model instance to a ModelTable instance. + + Args: + model: The Model instance to convert + database: The database instance (not used for this table) + + Returns: + A ModelTable instance + """ + return cls( + id=model.id, + name=model.name, + description=model.description, + simulation_function=compress_data(model.simulation_function), + ) + + def convert_to_model(self, database: "Database" = None) -> Model: + """Convert this ModelTable instance to a Model instance. + + Args: + database: The database instance (not used for this table) + + Returns: + A Model instance + """ + return Model( + id=self.id, + name=self.name, + description=self.description, + simulation_function=decompress_data(self.simulation_function), + ) + model_table_link = TableLink( model_cls=Model, table_cls=ModelTable, - model_to_table_custom_transforms=dict( - simulation_function=lambda m: compress_data(m.simulation_function), - ), - table_to_model_custom_transforms=dict( - simulation_function=lambda b: decompress_data(b), - ), ) diff --git a/src/policyengine/database/model_version_table.py b/src/policyengine/database/model_version_table.py index fe590ec9..86d19fed 100644 --- a/src/policyengine/database/model_version_table.py +++ b/src/policyengine/database/model_version_table.py @@ -1,12 +1,16 @@ from datetime import datetime +from typing import TYPE_CHECKING from uuid import uuid4 from sqlmodel import Field, SQLModel -from policyengine.models import ModelVersion +from policyengine.models import Model, ModelVersion from .link import TableLink +if TYPE_CHECKING: + from .database import Database + class ModelVersionTable(SQLModel, table=True): __tablename__ = "model_versions" @@ -17,12 +21,53 @@ class ModelVersionTable(SQLModel, table=True): description: str | None = Field(default=None) created_at: datetime = Field(default_factory=datetime.now) + @classmethod + def convert_from_model(cls, model: ModelVersion, database: "Database" = None) -> "ModelVersionTable": + """Convert a ModelVersion instance to a ModelVersionTable instance. + + Args: + model: The ModelVersion instance to convert + database: The database instance for persisting the model if needed + + Returns: + A ModelVersionTable instance + """ + # Ensure the Model is persisted if database is provided + if database and model.model: + database.set(model.model, commit=False) + + return cls( + id=model.id, + model_id=model.model.id if model.model else None, + version=model.version, + description=model.description, + created_at=model.created_at, + ) + + def convert_to_model(self, database: "Database" = None) -> ModelVersion: + """Convert this ModelVersionTable instance to a ModelVersion instance. 
+ + Args: + database: The database instance for resolving the model foreign key + + Returns: + A ModelVersion instance + """ + # Resolve the model foreign key + model = None + if database and self.model_id: + model = database.get(Model, id=self.model_id) + + return ModelVersion( + id=self.id, + model=model, + version=self.version, + description=self.description, + created_at=self.created_at, + ) + model_version_table_link = TableLink( model_cls=ModelVersion, table_cls=ModelVersionTable, - model_to_table_custom_transforms=dict( - model_id=lambda model_version: model_version.model.id, - ), - table_to_model_custom_transforms={}, ) diff --git a/src/policyengine/database/parameter_table.py b/src/policyengine/database/parameter_table.py index 500484e1..aef88e5a 100644 --- a/src/policyengine/database/parameter_table.py +++ b/src/policyengine/database/parameter_table.py @@ -1,9 +1,14 @@ +from typing import TYPE_CHECKING + from sqlmodel import Field, SQLModel -from policyengine.models import Parameter +from policyengine.models import Model, Parameter from .link import TableLink +if TYPE_CHECKING: + from .database import Database + class ParameterTable(SQLModel, table=True): __tablename__ = "parameters" @@ -15,17 +20,73 @@ class ParameterTable(SQLModel, table=True): ) # Part of composite key description: str | None = Field(default=None) data_type: str | None = Field(nullable=True) # Data type name + label: str | None = Field(default=None) + unit: str | None = Field(default=None) + + @classmethod + def convert_from_model(cls, model: Parameter, database: "Database" = None) -> "ParameterTable": + """Convert a Parameter instance to a ParameterTable instance. + + Args: + model: The Parameter instance to convert + database: The database instance for persisting the model if needed + + Returns: + A ParameterTable instance + """ + # Ensure the Model is persisted if database is provided + if database and model.model: + database.set(model.model, commit=False) + + return cls( + id=model.id, + model_id=model.model.id if model.model else None, + description=model.description, + data_type=model.data_type.__name__ if model.data_type else None, + label=model.label, + unit=model.unit, + ) + + def convert_to_model(self, database: "Database" = None) -> Parameter: + """Convert this ParameterTable instance to a Parameter instance. 
+ + Args: + database: The database instance for resolving the model foreign key + + Returns: + A Parameter instance + """ + from .model_table import ModelTable + from sqlmodel import select + + # Resolve the model foreign key + model = None + if database and self.model_id: + model_table = database.session.exec( + select(ModelTable).where(ModelTable.id == self.model_id) + ).first() + if model_table: + model = model_table.convert_to_model(database) + + # Convert data_type string back to type + data_type = None + if self.data_type: + try: + data_type = eval(self.data_type) + except: + data_type = None + + return Parameter( + id=self.id, + description=self.description, + data_type=data_type, + model=model, + label=self.label, + unit=self.unit, + ) parameter_table_link = TableLink( model_cls=Parameter, table_cls=ParameterTable, - primary_key=("id", "model_id"), # Composite primary key - model_to_table_custom_transforms=dict( - data_type=lambda p: p.data_type.__name__ if p.data_type else None, - model_id=lambda p: p.model.id if p.model else None, - ), - table_to_model_custom_transforms=dict( - data_type=lambda t: eval(t.data_type) if t.data_type else None - ), ) diff --git a/src/policyengine/database/parameter_value_table.py b/src/policyengine/database/parameter_value_table.py index 1bdc19c2..7bd02d0a 100644 --- a/src/policyengine/database/parameter_value_table.py +++ b/src/policyengine/database/parameter_value_table.py @@ -1,13 +1,16 @@ from datetime import datetime -from typing import Any +from typing import TYPE_CHECKING, Any from uuid import uuid4 from sqlmodel import JSON, Column, Field, SQLModel -from policyengine.models import ParameterValue +from policyengine.models import Parameter, ParameterValue from .link import TableLink +if TYPE_CHECKING: + from .database import Database + class ParameterValueTable(SQLModel, table=True): __tablename__ = "parameter_values" @@ -16,47 +19,90 @@ class ParameterValueTable(SQLModel, table=True): id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True) parameter_id: str = Field(nullable=False) # Part of composite foreign key model_id: str = Field(nullable=False) # Part of composite foreign key + policy_id: str | None = Field(default=None, foreign_key="policies.id", ondelete="CASCADE") # Link to policy value: Any | None = Field( default=None, sa_column=Column(JSON) ) # JSON field for any type start_date: datetime = Field(nullable=False) end_date: datetime | None = Field(default=None) + @classmethod + def convert_from_model(cls, model: ParameterValue, database: "Database" = None) -> "ParameterValueTable": + """Convert a ParameterValue instance to a ParameterValueTable instance. 
+ + Args: + model: The ParameterValue instance to convert + database: The database instance for persisting the parameter if needed + + Returns: + A ParameterValueTable instance + """ + import math + + # Ensure the Parameter is persisted if database is provided + if database and model.parameter: + database.set(model.parameter, commit=False) + + # Handle special float values + value = model.value + if isinstance(value, float): + if math.isinf(value): + value = "Infinity" if value > 0 else "-Infinity" + elif math.isnan(value): + value = "NaN" + + return cls( + id=model.id, + parameter_id=model.parameter.id if model.parameter else None, + model_id=model.parameter.model.id if model.parameter and model.parameter.model else None, + value=value, + start_date=model.start_date, + end_date=model.end_date, + ) + + def convert_to_model(self, database: "Database" = None) -> ParameterValue: + """Convert this ParameterValueTable instance to a ParameterValue instance. + + Args: + database: The database instance for resolving the parameter foreign key -def transform_value_to_table(pv): - """Transform value for storage, handling special float values.""" - import math + Returns: + A ParameterValue instance + """ + from .parameter_table import ParameterTable + from sqlmodel import select - value = pv.value - if isinstance(value, float): - if math.isinf(value): - return "Infinity" if value > 0 else "-Infinity" - elif math.isnan(value): - return "NaN" - return value + # Resolve the parameter foreign key + parameter = None + if database and self.parameter_id and self.model_id: + param_table = database.session.exec( + select(ParameterTable).where( + ParameterTable.id == self.parameter_id, + ParameterTable.model_id == self.model_id + ) + ).first() + if param_table: + parameter = param_table.convert_to_model(database) + # Handle special string values + value = self.value + if value == "Infinity": + value = float("inf") + elif value == "-Infinity": + value = float("-inf") + elif value == "NaN": + value = float("nan") -def transform_value_from_table(table_row): - """Transform value from storage, converting special strings back to floats.""" - value = table_row.value - if value == "Infinity": - return float("inf") - elif value == "-Infinity": - return float("-inf") - elif value == "NaN": - return float("nan") - return value + return ParameterValue( + id=self.id, + parameter=parameter, + value=value, + start_date=self.start_date, + end_date=self.end_date, + ) parameter_value_table_link = TableLink( model_cls=ParameterValue, table_cls=ParameterValueTable, - model_to_table_custom_transforms=dict( - parameter_id=lambda pv: pv.parameter.id, - model_id=lambda pv: pv.parameter.model.id, # Add model_id from parameter - value=transform_value_to_table, - ), - table_to_model_custom_transforms=dict( - value=transform_value_from_table, - ), ) diff --git a/src/policyengine/database/policy_table.py b/src/policyengine/database/policy_table.py index b8ce5a88..0ae381e4 100644 --- a/src/policyengine/database/policy_table.py +++ b/src/policyengine/database/policy_table.py @@ -1,4 +1,5 @@ from datetime import datetime +from typing import TYPE_CHECKING from uuid import uuid4 from sqlmodel import Field, SQLModel @@ -8,6 +9,9 @@ from .link import TableLink +if TYPE_CHECKING: + from .database import Database + class PolicyTable(SQLModel, table=True): __tablename__ = "policies" @@ -19,16 +23,114 @@ class PolicyTable(SQLModel, table=True): created_at: datetime = Field(default_factory=datetime.now) updated_at: datetime = 
Field(default_factory=datetime.now) + @classmethod + def convert_from_model(cls, model: Policy, database: "Database" = None) -> "PolicyTable": + """Convert a Policy instance to a PolicyTable instance. + + Args: + model: The Policy instance to convert + database: The database instance for persisting nested objects + + Returns: + A PolicyTable instance + """ + policy_table = cls( + id=model.id, + name=model.name, + description=model.description, + simulation_modifier=compress_data(model.simulation_modifier) if model.simulation_modifier else None, + created_at=model.created_at, + updated_at=model.updated_at, + ) + + # Handle nested parameter values if database is provided + if database and model.parameter_values: + from .parameter_value_table import ParameterValueTable + from sqlmodel import select + + # First ensure the policy table is saved to the database + # This is necessary so the foreign key constraint is satisfied + # Check if it already exists + existing_policy = database.session.exec( + select(PolicyTable).where(PolicyTable.id == model.id) + ).first() + + if not existing_policy: + database.session.add(policy_table) + database.session.flush() + + # Track which parameter value IDs we want to keep + desired_pv_ids = {pv.id for pv in model.parameter_values} + + # Delete only parameter values linked to this policy that are NOT in the new list + existing_pvs = database.session.exec( + select(ParameterValueTable).where(ParameterValueTable.policy_id == model.id) + ).all() + for pv in existing_pvs: + if pv.id not in desired_pv_ids: + database.session.delete(pv) + + # Now save/update the parameter values + for param_value in model.parameter_values: + # Check if this parameter value already exists in the database + existing_pv = database.session.exec( + select(ParameterValueTable).where(ParameterValueTable.id == param_value.id) + ).first() + + if existing_pv: + # Update existing parameter value + pv_table = ParameterValueTable.convert_from_model(param_value, database) + existing_pv.parameter_id = pv_table.parameter_id + existing_pv.model_id = pv_table.model_id + existing_pv.policy_id = model.id + existing_pv.value = pv_table.value + existing_pv.start_date = pv_table.start_date + existing_pv.end_date = pv_table.end_date + else: + # Create new parameter value + pv_table = ParameterValueTable.convert_from_model(param_value, database) + pv_table.policy_id = model.id # Link to this policy + database.session.add(pv_table) + database.session.flush() + + return policy_table + + def convert_to_model(self, database: "Database" = None) -> Policy: + """Convert this PolicyTable instance to a Policy instance. 
+ + Args: + database: The database instance for loading nested objects + + Returns: + A Policy instance + """ + # Load nested parameter values if database is provided + parameter_values = [] + if database: + from .parameter_value_table import ParameterValueTable + from sqlmodel import select + + # Query for all parameter values linked to this policy + pv_tables = database.session.exec( + select(ParameterValueTable).where(ParameterValueTable.policy_id == self.id) + ).all() + + # Convert each one to a model + for pv_table in pv_tables: + parameter_values.append(pv_table.convert_to_model(database)) + + return Policy( + id=self.id, + name=self.name, + description=self.description, + parameter_values=parameter_values, + simulation_modifier=decompress_data(self.simulation_modifier) if self.simulation_modifier else None, + created_at=self.created_at, + updated_at=self.updated_at, + ) + policy_table_link = TableLink( model_cls=Policy, table_cls=PolicyTable, - model_to_table_custom_transforms=dict( - simulation_modifier=lambda p: compress_data(p.simulation_modifier) - if p.simulation_modifier - else None, - ), - table_to_model_custom_transforms=dict( - simulation_modifier=lambda b: decompress_data(b) if b else None, - ), ) diff --git a/src/policyengine/database/report_element_table.py b/src/policyengine/database/report_element_table.py index 477bfcc3..3db4e481 100644 --- a/src/policyengine/database/report_element_table.py +++ b/src/policyengine/database/report_element_table.py @@ -2,11 +2,15 @@ from datetime import datetime from sqlmodel import Field, SQLModel +from typing import TYPE_CHECKING from policyengine.models.report_element import ReportElement from .link import TableLink +if TYPE_CHECKING: + from .database import Database + class ReportElementTable(SQLModel, table=True, extend_existing=True): __tablename__ = "report_elements" @@ -41,6 +45,51 @@ class ReportElementTable(SQLModel, table=True, extend_existing=True): created_at: datetime = Field(default_factory=datetime.utcnow) updated_at: datetime = Field(default_factory=datetime.utcnow) + @classmethod + def convert_from_model(cls, model: ReportElement, database: "Database" = None) -> "ReportElementTable": + """Convert a ReportElement instance to a ReportElementTable instance.""" + return cls( + id=model.id, + label=model.label, + type=model.type, + data_table=model.data_table, + chart_type=model.chart_type, + x_axis_variable=model.x_axis_variable, + y_axis_variable=model.y_axis_variable, + group_by=model.group_by, + color_by=model.color_by, + size_by=model.size_by, + markdown_content=model.markdown_content, + report_id=model.report_id, + user_id=model.user_id, + position=model.position, + visible=model.visible, + created_at=model.created_at, + updated_at=model.updated_at, + ) + + def convert_to_model(self, database: "Database" = None) -> ReportElement: + """Convert this ReportElementTable instance to a ReportElement instance.""" + return ReportElement( + id=self.id, + label=self.label, + type=self.type, + data_table=self.data_table, + chart_type=self.chart_type, + x_axis_variable=self.x_axis_variable, + y_axis_variable=self.y_axis_variable, + group_by=self.group_by, + color_by=self.color_by, + size_by=self.size_by, + markdown_content=self.markdown_content, + report_id=self.report_id, + user_id=self.user_id, + position=self.position, + visible=self.visible, + created_at=self.created_at, + updated_at=self.updated_at, + ) + report_element_table_link = TableLink( model_cls=ReportElement, diff --git a/src/policyengine/database/report_table.py 
b/src/policyengine/database/report_table.py index 9ac473b5..79c11cf0 100644 --- a/src/policyengine/database/report_table.py +++ b/src/policyengine/database/report_table.py @@ -2,11 +2,15 @@ from datetime import datetime from sqlmodel import Field, SQLModel +from typing import TYPE_CHECKING from policyengine.models.report import Report from .link import TableLink +if TYPE_CHECKING: + from .database import Database + class ReportTable(SQLModel, table=True, extend_existing=True): __tablename__ = "reports" @@ -17,6 +21,98 @@ class ReportTable(SQLModel, table=True, extend_existing=True): label: str = Field(nullable=False) created_at: datetime = Field(default_factory=datetime.utcnow) + @classmethod + def convert_from_model(cls, model: Report, database: "Database" = None) -> "ReportTable": + """Convert a Report instance to a ReportTable instance.""" + report_table = cls( + id=model.id, + label=model.label, + created_at=model.created_at, + ) + + # Handle nested report elements if database is provided + if database and model.elements: + from .report_element_table import ReportElementTable + from sqlmodel import select + + # First ensure the report table is saved to the database + # This is necessary so the foreign key constraint is satisfied + # Check if it already exists + existing_report = database.session.exec( + select(ReportTable).where(ReportTable.id == model.id) + ).first() + + if not existing_report: + database.session.add(report_table) + database.session.flush() + + # Track which element IDs we want to keep + desired_elem_ids = {elem.id for elem in model.elements} + + # Delete only elements linked to this report that are NOT in the new list + existing_elems = database.session.exec( + select(ReportElementTable).where(ReportElementTable.report_id == model.id) + ).all() + for elem in existing_elems: + if elem.id not in desired_elem_ids: + database.session.delete(elem) + + # Now save/update the elements + for i, element in enumerate(model.elements): + # Check if this element already exists in the database + existing_elem = database.session.exec( + select(ReportElementTable).where(ReportElementTable.id == element.id) + ).first() + + if existing_elem: + # Update existing element + elem_table = ReportElementTable.convert_from_model(element, database) + existing_elem.report_id = model.id + existing_elem.position = i + existing_elem.label = elem_table.label + existing_elem.type = elem_table.type + existing_elem.markdown_content = elem_table.markdown_content + existing_elem.chart_type = elem_table.chart_type + existing_elem.x_axis_variable = elem_table.x_axis_variable + existing_elem.y_axis_variable = elem_table.y_axis_variable + existing_elem.baseline_simulation_id = elem_table.baseline_simulation_id + existing_elem.reform_simulation_id = elem_table.reform_simulation_id + else: + # Create new element + elem_table = ReportElementTable.convert_from_model(element, database) + elem_table.report_id = model.id # Link to this report + elem_table.position = i # Maintain order + database.session.add(elem_table) + database.session.flush() + + return report_table + + def convert_to_model(self, database: "Database" = None) -> Report: + """Convert this ReportTable instance to a Report instance.""" + # Load nested report elements if database is provided + elements = [] + if database: + from .report_element_table import ReportElementTable + from sqlmodel import select + + # Query for all elements linked to this report, ordered by position + elem_tables = database.session.exec( + select(ReportElementTable) + 
.where(ReportElementTable.report_id == self.id) + .order_by(ReportElementTable.position) + ).all() + + # Convert each one to a model + for elem_table in elem_tables: + elements.append(elem_table.convert_to_model(database)) + + return Report( + id=self.id, + label=self.label, + created_at=self.created_at, + elements=elements, + ) + report_table_link = TableLink( model_cls=Report, diff --git a/src/policyengine/database/simulation_table.py b/src/policyengine/database/simulation_table.py index 483a78be..de45a419 100644 --- a/src/policyengine/database/simulation_table.py +++ b/src/policyengine/database/simulation_table.py @@ -1,13 +1,17 @@ from datetime import datetime +from typing import TYPE_CHECKING from uuid import uuid4 from sqlmodel import Field, SQLModel -from policyengine.models import Simulation +from policyengine.models import Dataset, Dynamic, Model, ModelVersion, Policy, Simulation from policyengine.utils.compress import compress_data, decompress_data from .link import TableLink +if TYPE_CHECKING: + from .database import Database + class SimulationTable(SQLModel, table=True): __tablename__ = "simulations" @@ -30,21 +34,192 @@ class SimulationTable(SQLModel, table=True): result: bytes | None = Field(default=None) + @classmethod + def convert_from_model(cls, model: Simulation, database: "Database" = None) -> "SimulationTable": + """Convert a Simulation instance to a SimulationTable instance. + + Args: + model: The Simulation instance to convert + database: The database instance for persisting foreign objects if needed + + Returns: + A SimulationTable instance + """ + # Ensure all foreign objects are persisted if database is provided + if database: + if model.policy: + database.set(model.policy, commit=False) + if model.dynamic: + database.set(model.dynamic, commit=False) + if model.dataset: + database.set(model.dataset, commit=False) + if model.model: + database.set(model.model, commit=False) + if model.model_version: + database.set(model.model_version, commit=False) + + sim_table = cls( + id=model.id, + created_at=model.created_at, + updated_at=model.updated_at, + policy_id=model.policy.id if model.policy else None, + dynamic_id=model.dynamic.id if model.dynamic else None, + dataset_id=model.dataset.id if model.dataset else None, + model_id=model.model.id if model.model else None, + model_version_id=model.model_version.id if model.model_version else None, + result=compress_data(model.result) if model.result else None, + ) + + # Handle nested aggregates if database is provided + if database and model.aggregates: + from .aggregate import AggregateTable + from sqlmodel import select + + # First ensure the simulation table is saved to the database + # This is necessary so the foreign key constraint is satisfied + # Check if it already exists + existing_sim = database.session.exec( + select(SimulationTable).where(SimulationTable.id == model.id) + ).first() + + if not existing_sim: + database.session.add(sim_table) + database.session.flush() + + # Track which aggregate IDs we want to keep + desired_agg_ids = {agg.id for agg in model.aggregates} + + # Delete only aggregates linked to this simulation that are NOT in the new list + existing_aggs = database.session.exec( + select(AggregateTable).where(AggregateTable.simulation_id == model.id) + ).all() + for agg in existing_aggs: + if agg.id not in desired_agg_ids: + database.session.delete(agg) + + # Now save/update the aggregates + for aggregate in model.aggregates: + # Check if this aggregate already exists in the database + existing_agg = 
database.session.exec( + select(AggregateTable).where(AggregateTable.id == aggregate.id) + ).first() + + if existing_agg: + # Update existing aggregate + agg_table = AggregateTable.convert_from_model(aggregate, database) + existing_agg.simulation_id = agg_table.simulation_id + existing_agg.entity = agg_table.entity + existing_agg.variable_name = agg_table.variable_name + existing_agg.year = agg_table.year + existing_agg.filter_variable_name = agg_table.filter_variable_name + existing_agg.filter_variable_value = agg_table.filter_variable_value + existing_agg.filter_variable_leq = agg_table.filter_variable_leq + existing_agg.filter_variable_geq = agg_table.filter_variable_geq + existing_agg.aggregate_function = agg_table.aggregate_function + existing_agg.value = agg_table.value + else: + # Create new aggregate + agg_table = AggregateTable.convert_from_model(aggregate, database) + database.session.add(agg_table) + database.session.flush() + + return sim_table + + def convert_to_model(self, database: "Database" = None) -> Simulation: + """Convert this SimulationTable instance to a Simulation instance. + + Args: + database: The database instance for resolving foreign keys + + Returns: + A Simulation instance + """ + from sqlmodel import select + + from .model_version_table import ModelVersionTable + from .policy_table import PolicyTable + from .dataset_table import DatasetTable + from .model_table import ModelTable + from .dynamic_table import DynamicTable + + # Resolve all foreign keys + policy = None + dynamic = None + dataset = None + model = None + model_version = None + + if database: + if self.policy_id: + policy_table = database.session.exec( + select(PolicyTable).where(PolicyTable.id == self.policy_id) + ).first() + if policy_table: + policy = policy_table.convert_to_model(database) + + if self.dynamic_id: + try: + dynamic_table = database.session.exec( + select(DynamicTable).where(DynamicTable.id == self.dynamic_id) + ).first() + if dynamic_table: + dynamic = dynamic_table.convert_to_model(database) + except: + # Dynamic table might not be defined yet + dynamic = database.get(Dynamic, id=self.dynamic_id) + + if self.dataset_id: + dataset_table = database.session.exec( + select(DatasetTable).where(DatasetTable.id == self.dataset_id) + ).first() + if dataset_table: + dataset = dataset_table.convert_to_model(database) + + if self.model_id: + model_table = database.session.exec( + select(ModelTable).where(ModelTable.id == self.model_id) + ).first() + if model_table: + model = model_table.convert_to_model(database) + + if self.model_version_id: + version_table = database.session.exec( + select(ModelVersionTable).where(ModelVersionTable.id == self.model_version_id) + ).first() + if version_table: + model_version = version_table.convert_to_model(database) + + # Load aggregates + aggregates = [] + if database: + from .aggregate import AggregateTable + from sqlmodel import select + + agg_tables = database.session.exec( + select(AggregateTable).where(AggregateTable.simulation_id == self.id) + ).all() + + for agg_table in agg_tables: + # Don't pass database to avoid circular reference issues + # The simulation reference will be set separately + agg_model = agg_table.convert_to_model(None) + aggregates.append(agg_model) + + return Simulation( + id=self.id, + created_at=self.created_at, + updated_at=self.updated_at, + policy=policy, + dynamic=dynamic, + dataset=dataset, + model=model, + model_version=model_version, + result=decompress_data(self.result) if self.result else None, + 
aggregates=aggregates, + ) + simulation_table_link = TableLink( model_cls=Simulation, table_cls=SimulationTable, - model_to_table_custom_transforms=dict( - policy_id=lambda s: s.policy.id if s.policy else None, - dynamic_id=lambda s: s.dynamic.id if s.dynamic else None, - dataset_id=lambda s: s.dataset.id, - model_id=lambda s: s.model.id, - model_version_id=lambda s: s.model_version.id - if s.model_version - else None, - result=lambda s: compress_data(s.result) if s.result else None, - ), - table_to_model_custom_transforms=dict( - result=lambda b: decompress_data(b) if b else None, - ), ) diff --git a/src/policyengine/database/table_mixin.py b/src/policyengine/database/table_mixin.py new file mode 100644 index 00000000..a29cdeb6 --- /dev/null +++ b/src/policyengine/database/table_mixin.py @@ -0,0 +1,80 @@ +from typing import TYPE_CHECKING, Any, ClassVar, TypeVar + +from pydantic import BaseModel +from sqlmodel import SQLModel + +if TYPE_CHECKING: + from .database import Database + +T = TypeVar("T", bound=BaseModel) + + +class TableConversionMixin: + """Mixin class for SQLModel tables to provide conversion methods between table instances and Pydantic models.""" + + _model_cls: ClassVar[type[BaseModel]] = None + _foreign_key_fields: ClassVar[dict[str, type[BaseModel]]] = {} + + @classmethod + def convert_from_model(cls, model: BaseModel, database: "Database" = None) -> SQLModel: + """Convert a Pydantic model instance to a table instance, resolving foreign objects to IDs. + + Args: + model: The Pydantic model instance to convert + database: The database instance for resolving foreign objects (optional) + + Returns: + An instance of the SQLModel table class + """ + data = {} + + for field_name in cls.__annotations__.keys(): + # Check if this field is a foreign key that needs resolution + if field_name in cls._foreign_key_fields: + # Extract ID from the nested object + nested_obj = getattr(model, field_name.replace("_id", ""), None) + if nested_obj: + # If we need to ensure the foreign object exists in DB + if database: + database.set(nested_obj, commit=False) + data[field_name] = nested_obj.id if hasattr(nested_obj, "id") else None + else: + data[field_name] = None + elif hasattr(model, field_name): + # Direct field mapping + data[field_name] = getattr(model, field_name) + + return cls(**data) + + @classmethod + def convert_to_model(cls, table_instance: SQLModel, database: "Database" = None) -> BaseModel: + """Convert a table instance to a Pydantic model, resolving foreign key IDs to objects. 
+ + Args: + table_instance: The SQLModel table instance to convert + database: The database instance for resolving foreign keys (required if foreign keys exist) + + Returns: + An instance of the Pydantic model class + """ + if cls._model_cls is None: + raise ValueError(f"Model class not set for {cls.__name__}") + + data = {} + + for field_name in cls._model_cls.__annotations__.keys(): + # Check if we need to resolve a foreign key + fk_field = f"{field_name}_id" + if fk_field in cls._foreign_key_fields and database: + # Resolve the foreign key to an object + fk_id = getattr(table_instance, fk_field, None) + if fk_id: + foreign_model_cls = cls._foreign_key_fields[fk_field] + data[field_name] = database.get(foreign_model_cls, id=fk_id) + else: + data[field_name] = None + elif hasattr(table_instance, field_name): + # Direct field mapping + data[field_name] = getattr(table_instance, field_name) + + return cls._model_cls(**data) \ No newline at end of file diff --git a/src/policyengine/database/user_table.py b/src/policyengine/database/user_table.py index 8c79f73a..d663ac8f 100644 --- a/src/policyengine/database/user_table.py +++ b/src/policyengine/database/user_table.py @@ -2,11 +2,15 @@ from datetime import datetime from sqlmodel import Field, SQLModel +from typing import TYPE_CHECKING from policyengine.models.user import User from .link import TableLink +if TYPE_CHECKING: + from .database import Database + class UserTable(SQLModel, table=True, extend_existing=True): __tablename__ = "users" @@ -21,6 +25,31 @@ class UserTable(SQLModel, table=True, extend_existing=True): created_at: datetime = Field(default_factory=datetime.utcnow) updated_at: datetime = Field(default_factory=datetime.utcnow) + @classmethod + def convert_from_model(cls, model: User, database: "Database" = None) -> "UserTable": + """Convert a User instance to a UserTable instance.""" + return cls( + id=model.id, + username=model.username, + first_name=model.first_name, + last_name=model.last_name, + email=model.email, + created_at=model.created_at, + updated_at=model.updated_at, + ) + + def convert_to_model(self, database: "Database" = None) -> User: + """Convert this UserTable instance to a User instance.""" + return User( + id=self.id, + username=self.username, + first_name=self.first_name, + last_name=self.last_name, + email=self.email, + created_at=self.created_at, + updated_at=self.updated_at, + ) + user_table_link = TableLink( model_cls=User, diff --git a/src/policyengine/database/versioned_dataset_table.py b/src/policyengine/database/versioned_dataset_table.py index 52aa207b..4e1524c9 100644 --- a/src/policyengine/database/versioned_dataset_table.py +++ b/src/policyengine/database/versioned_dataset_table.py @@ -1,11 +1,15 @@ from uuid import uuid4 from sqlmodel import Field, SQLModel +from typing import TYPE_CHECKING from policyengine.models import VersionedDataset from .link import TableLink +if TYPE_CHECKING: + from .database import Database + class VersionedDatasetTable(SQLModel, table=True): __tablename__ = "versioned_datasets" @@ -17,12 +21,25 @@ class VersionedDatasetTable(SQLModel, table=True): default=None, foreign_key="models.id", ondelete="SET NULL" ) + @classmethod + def convert_from_model(cls, model: VersionedDataset, database: "Database" = None) -> "VersionedDatasetTable": + """Convert a VersionedDataset instance to a VersionedDatasetTable instance.""" + return cls( + id=model.id, + name=model.name, + description=model.description, + model_id=model.model.id if model.model else None, + ) + + def convert_to_model(self, database: "Database" = None) ->
VersionedDataset: + """Convert this VersionedDatasetTable instance to a VersionedDataset instance.""" + return VersionedDataset( + id=self.id, + name=self.name, + description=self.description, + ) + versioned_dataset_table_link = TableLink( model_cls=VersionedDataset, table_cls=VersionedDatasetTable, - model_to_table_custom_transforms=dict( - model_id=lambda vd: vd.model.id if vd.model else None, - ), - table_to_model_custom_transforms={}, ) diff --git a/src/policyengine/models/__init__.py b/src/policyengine/models/__init__.py index b92592b9..652d46cd 100644 --- a/src/policyengine/models/__init__.py +++ b/src/policyengine/models/__init__.py @@ -28,3 +28,9 @@ from .simulation import Simulation as Simulation from .user import User as User from .versioned_dataset import VersionedDataset as VersionedDataset + +# Rebuild models to handle circular references +from .aggregate import Aggregate +from .simulation import Simulation +Aggregate.model_rebuild() +Simulation.model_rebuild() diff --git a/src/policyengine/models/aggregate.py b/src/policyengine/models/aggregate.py index b25d9d1a..86ebe996 100644 --- a/src/policyengine/models/aggregate.py +++ b/src/policyengine/models/aggregate.py @@ -1,9 +1,10 @@ from enum import Enum from typing import TYPE_CHECKING, Literal +from uuid import uuid4 import pandas as pd from microdf import MicroDataFrame -from pydantic import BaseModel +from pydantic import BaseModel, Field if TYPE_CHECKING: from policyengine.models import Simulation @@ -16,7 +17,8 @@ class AggregateType(str, Enum): class Aggregate(BaseModel): - simulation: "Simulation" + id: str = Field(default_factory=lambda: str(uuid4())) + simulation: "Simulation | None" = None entity: str variable_name: str year: int | None = None @@ -27,6 +29,7 @@ class Aggregate(BaseModel): aggregate_function: Literal[ AggregateType.SUM, AggregateType.MEAN, AggregateType.COUNT ] + reportelement_id: str | None = None value: float | None = None @@ -36,6 +39,8 @@ def run(aggregates: list["Aggregate"]) -> list["Aggregate"]: results = [] tables = aggregates[0].simulation.result + # copy tables to ensure we don't modify original dataframes + tables = {k: v.copy() for k, v in tables.items()} for table in tables: tables[table] = pd.DataFrame(tables[table]) weight_col = f"{table}_weight" diff --git a/src/policyengine/models/baseline_parameter_value.py b/src/policyengine/models/baseline_parameter_value.py index 65cd4aba..8afb6e22 100644 --- a/src/policyengine/models/baseline_parameter_value.py +++ b/src/policyengine/models/baseline_parameter_value.py @@ -1,12 +1,14 @@ from datetime import datetime +from uuid import uuid4 -from pydantic import BaseModel +from pydantic import BaseModel, Field from .model_version import ModelVersion from .parameter import Parameter class BaselineParameterValue(BaseModel): + id: str = Field(default_factory=lambda: str(uuid4())) parameter: Parameter model_version: ModelVersion value: float | int | str | bool | list | None = None diff --git a/src/policyengine/models/model.py b/src/policyengine/models/model.py index e898f489..89cac9b8 100644 --- a/src/policyengine/models/model.py +++ b/src/policyengine/models/model.py @@ -45,6 +45,8 @@ def create_seed_objects(self, model_version): description=parameter.description, data_type=None, model=self, + label=parameter.metadata.get("label"), + unit=parameter.metadata.get("unit"), ) parameters.append(param) if isinstance(parameter, CoreParameter): diff --git a/src/policyengine/models/parameter.py b/src/policyengine/models/parameter.py index c438f4f6..ec7ef7be 
100644 --- a/src/policyengine/models/parameter.py +++ b/src/policyengine/models/parameter.py @@ -10,3 +10,5 @@ class Parameter(BaseModel): description: str | None = None data_type: type | None = None model: Model | None = None + label: str | None = None + unit: str | None = None diff --git a/src/policyengine/models/policyengine_uk.py b/src/policyengine/models/policyengine_uk.py index 22b089a8..bb66ebe5 100644 --- a/src/policyengine/models/policyengine_uk.py +++ b/src/policyengine/models/policyengine_uk.py @@ -94,7 +94,8 @@ def simulation_modifier(sim: Microsimulation): if correct_entity: output_data[entity][variable.name] = sim.calculate( variable.name - ) + ).values + output_data[entity] = pd.DataFrame(output_data[entity]) return output_data diff --git a/src/policyengine/models/policyengine_us.py b/src/policyengine/models/policyengine_us.py index 8886f0b8..9e2eeb7d 100644 --- a/src/policyengine/models/policyengine_us.py +++ b/src/policyengine/models/policyengine_us.py @@ -94,7 +94,7 @@ def simulation_modifier(sim: Microsimulation): continue if not correct_entity: continue - output_data[entity][variable.name] = sim.calculate(variable.name) + output_data[entity][variable.name] = sim.calculate(variable.name).values return output_data diff --git a/src/policyengine/models/report.py b/src/policyengine/models/report.py index 6a6442b3..2ae0cd3b 100644 --- a/src/policyengine/models/report.py +++ b/src/policyengine/models/report.py @@ -1,10 +1,20 @@ import uuid from datetime import datetime +from typing import TYPE_CHECKING, ForwardRef from pydantic import BaseModel, Field +if TYPE_CHECKING: + from .report_element import ReportElement + class Report(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4())) label: str created_at: datetime | None = None + elements: list[ForwardRef("ReportElement")] = Field(default_factory=list) + + +# Import after class definition to avoid circular import +from .report_element import ReportElement +Report.model_rebuild() diff --git a/src/policyengine/models/simulation.py b/src/policyengine/models/simulation.py index 8993ebe6..a6ed7a5a 100644 --- a/src/policyengine/models/simulation.py +++ b/src/policyengine/models/simulation.py @@ -23,6 +23,7 @@ class Simulation(BaseModel): model: Model model_version: ModelVersion result: Any | None = None + aggregates: list = Field(default_factory=list) # Will be list[Aggregate] but avoid circular import def run(self): self.result = self.model.simulation_function( From 93133a1e339fc23ce8429c4a2d1b448924424306 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Mon, 29 Sep 2025 09:19:13 +0100 Subject: [PATCH 02/35] Add fields --- docs/quickstart.ipynb | 42 +++++++++---------- .../database/baseline_variable_table.py | 18 ++++---- .../database/report_element_table.py | 11 ++++- src/policyengine/models/aggregate.py | 40 ++++++++++++++++-- src/policyengine/models/policyengine_uk.py | 2 - src/policyengine/models/report_element.py | 2 + 6 files changed, 80 insertions(+), 35 deletions(-) diff --git a/docs/quickstart.ipynb b/docs/quickstart.ipynb index 3932cd2a..46c7bdfd 100644 --- a/docs/quickstart.ipynb +++ b/docs/quickstart.ipynb @@ -66,13 +66,13 @@ "£500,000" ], "y": [ - 6601567.352745674, - 10308068.997645264, - 7214847.8757442, - 4260541.368792315, - 1683256.5811573816, - 1320122.584893554, - 326076.2586429423, + 6604006.784030474, + 10307871.58979292, + 7152632.348702732, + 4284865.771267385, + 1718930.0846310211, + 1320096.830079406, + 326077.61111739336, 187608.23132836912, 63106.63353048405, 41838.373842805624 @@ 
-1110,13 +1110,13 @@ "£500,000" ], "y": [ - 6601567.352745674, - 10308068.997645264, - 7214847.8757442, - 4260541.368792315, - 1683256.5811573816, - 1320122.584893554, - 326076.2586429423, + 6604006.784030474, + 10307871.58979292, + 7152632.348702732, + 4284865.771267385, + 1718930.0846310211, + 1320096.830079406, + 326077.61111739336, 187608.23132836912, 63106.63353048405, 41838.373842805624 @@ -1148,12 +1148,12 @@ "£500,000" ], "y": [ - 6146242.297314803, - 10316900.228889126, - 6968299.661946169, - 4446106.99205335, - 1992976.835352243, - 1472301.8444251162, + 6148681.728599603, + 10315633.853035238, + 6911688.091380573, + 4489387.35271926, + 2005084.3325594077, + 1472352.5016867262, 341808.24952948757, 218180.35939976107, 63106.63353048405, @@ -2224,7 +2224,7 @@ { "data": { "text/plain": [ - "Policy(id='64792d73-b52a-431b-a22a-1b5cf3ae4551', name='Increase personal allowance to £20,000', description='A policy to increase the personal allowance for income tax to £20,000.', parameter_values=[ParameterValue(id='5ffda0b7-c70f-4cef-b45b-91fb663d12db', parameter=Parameter(id='gov.hmrc.income_tax.allowances.personal_allowance.amount', description=None, data_type=None, model=Model(id='policyengine_uk', name='PolicyEngine UK', description=\"PolicyEngine's open-source tax-benefit microsimulation model.\", simulation_function=), label=None, unit=None), value=20000, start_date=datetime.datetime(2029, 1, 1, 0, 0), end_date=None)], simulation_modifier=None, created_at=datetime.datetime(2025, 9, 28, 19, 57, 41, 515173), updated_at=datetime.datetime(2025, 9, 28, 19, 57, 41, 515176))" + "Policy(id='9e558aa7-392f-4203-bf7d-769929e4f01e', name='Increase personal allowance to £20,000', description='A policy to increase the personal allowance for income tax to £20,000.', parameter_values=[ParameterValue(id='722765c7-5a97-4726-8156-a8f8fb26cb09', parameter=Parameter(id='gov.hmrc.income_tax.allowances.personal_allowance.amount', description=None, data_type=None, model=Model(id='policyengine_uk', name='PolicyEngine UK', description=\"PolicyEngine's open-source tax-benefit microsimulation model.\", simulation_function=), label=None, unit=None), value=20000, start_date=datetime.datetime(2029, 1, 1, 0, 0), end_date=None)], simulation_modifier=None, created_at=datetime.datetime(2025, 9, 28, 23, 42, 32, 430810), updated_at=datetime.datetime(2025, 9, 28, 23, 42, 32, 430842))" ] }, "execution_count": 5, diff --git a/src/policyengine/database/baseline_variable_table.py b/src/policyengine/database/baseline_variable_table.py index 6d836ee2..e7773c80 100644 --- a/src/policyengine/database/baseline_variable_table.py +++ b/src/policyengine/database/baseline_variable_table.py @@ -2,7 +2,6 @@ from typing import TYPE_CHECKING from policyengine.models import ModelVersion, BaselineVariable -from policyengine.utils.compress import compress_data, decompress_data from .link import TableLink @@ -24,13 +23,11 @@ class BaselineVariableTable(SQLModel, table=True): entity: str = Field(nullable=False) label: str | None = Field(default=None) description: str | None = Field(default=None) - data_type: bytes | None = Field(default=None) # Pickled type + data_type: str | None = Field(default=None) # Data type name @classmethod def convert_from_model(cls, model: BaselineVariable, database: "Database" = None) -> "BaselineVariableTable": """Convert a BaselineVariable instance to a BaselineVariableTable instance.""" - from policyengine.utils.compress import compress_data - # Ensure foreign objects are persisted if database is provided if 
database and model.model_version: + database.set(model.model_version, commit=False) @@ -42,12 +39,11 @@ def convert_from_model(cls, model: BaselineVariable, database: "Database" = None entity=model.entity, label=model.label, description=model.description, - data_type=compress_data(model.data_type) if model.data_type else None, + data_type=model.data_type.__name__ if model.data_type else None, ) def convert_to_model(self, database: "Database" = None) -> BaselineVariable: """Convert this BaselineVariableTable instance to a BaselineVariable instance.""" - from policyengine.utils.compress import decompress_data from .model_version_table import ModelVersionTable from sqlmodel import select @@ -61,13 +57,21 @@ def convert_to_model(self, database: "Database" = None) -> BaselineVariable: if version_table: model_version = version_table.convert_to_model(database) + # Convert the stored data_type name back to the builtin type (avoids eval on DB contents) + import builtins + data_type = None + if self.data_type: + data_type = getattr(builtins, self.data_type, None) + return BaselineVariable( id=self.id, model_version=model_version, entity=self.entity, label=self.label, description=self.description, - data_type=decompress_data(self.data_type) if self.data_type else None, + data_type=data_type, ) diff --git a/src/policyengine/database/report_element_table.py b/src/policyengine/database/report_element_table.py index 3db4e481..86e57f9b 100644 --- a/src/policyengine/database/report_element_table.py +++ b/src/policyengine/database/report_element_table.py @@ -1,7 +1,7 @@ import uuid from datetime import datetime -from sqlmodel import Field, SQLModel +from sqlmodel import Field, SQLModel, Column, JSON from typing import TYPE_CHECKING from policyengine.models.report_element import ReportElement @@ -40,8 +40,11 @@ class ReportElementTable(SQLModel, table=True, extend_existing=True): # Metadata report_id: str | None = Field(default=None, foreign_key="reports.id") user_id: str | None = Field(default=None, foreign_key="users.id") + model_version_id: str | None = Field(default=None, foreign_key="model_versions.id") position: int | None = Field(default=None) visible: bool | None = Field(default=True) + custom_config: dict | None = Field(default=None, sa_column=Column(JSON)) + report_element_metadata: dict | None = Field(default=None, sa_column=Column(JSON)) created_at: datetime = Field(default_factory=datetime.utcnow) updated_at: datetime = Field(default_factory=datetime.utcnow) @@ -62,8 +65,11 @@ def convert_from_model(cls, model: ReportElement, database: "Database" = None) - markdown_content=model.markdown_content, report_id=model.report_id, user_id=model.user_id, + model_version_id=model.model_version_id, position=model.position, visible=model.visible, + custom_config=model.custom_config, + report_element_metadata=model.report_element_metadata, created_at=model.created_at, updated_at=model.updated_at, ) @@ -84,8 +90,11 @@ def convert_to_model(self, database: "Database" = None) -> ReportElement: markdown_content=self.markdown_content, report_id=self.report_id, user_id=self.user_id, + model_version_id=self.model_version_id, position=self.position, visible=self.visible, + custom_config=self.custom_config, + report_element_metadata=self.report_element_metadata, created_at=self.created_at, updated_at=self.updated_at, ) diff --git a/src/policyengine/models/aggregate.py b/src/policyengine/models/aggregate.py index 86ebe996..5ae67569 100644 --- a/src/policyengine/models/aggregate.py +++ b/src/policyengine/models/aggregate.py @@ -35,11 +35,43 @@ class Aggregate(BaseModel):
@staticmethod def run(aggregates: list["Aggregate"]) -> list["Aggregate"]: - # Assumes that all aggregates are from the same simulation + """Process aggregates, handling multiple simulations if necessary.""" + # Group aggregates by simulation + simulation_groups = {} + for agg in aggregates: + sim_id = id(agg.simulation) if agg.simulation else None + if sim_id not in simulation_groups: + simulation_groups[sim_id] = [] + simulation_groups[sim_id].append(agg) + + # Process each simulation group separately + all_results = [] + for sim_id, sim_aggregates in simulation_groups.items(): + if not sim_aggregates: + continue + + # Get the simulation from the first aggregate in this group + simulation = sim_aggregates[0].simulation + if simulation is None: + raise ValueError("Aggregate has no simulation attached") + + # Process this simulation's aggregates + group_results = Aggregate._process_simulation_aggregates( + sim_aggregates, simulation + ) + all_results.extend(group_results) + + return all_results + + @staticmethod + def _process_simulation_aggregates( + aggregates: list["Aggregate"], simulation: "Simulation" + ) -> list["Aggregate"]: + """Process aggregates for a single simulation.""" results = [] - tables = aggregates[0].simulation.result - # copy tables to ensure we don't modify original dataframes + tables = simulation.result + # Copy tables to ensure we don't modify original dataframes tables = {k: v.copy() for k, v in tables.items()} for table in tables: tables[table] = pd.DataFrame(tables[table]) @@ -64,7 +96,7 @@ def run(aggregates: list["Aggregate"]) -> list["Aggregate"]: df = table if agg.year is None: - agg.year = aggregates[0].simulation.dataset.year + agg.year = simulation.dataset.year if agg.filter_variable_name is not None: if agg.filter_variable_name not in df.columns: diff --git a/src/policyengine/models/policyengine_uk.py b/src/policyengine/models/policyengine_uk.py index bb66ebe5..5b97ccfb 100644 --- a/src/policyengine/models/policyengine_uk.py +++ b/src/policyengine/models/policyengine_uk.py @@ -57,8 +57,6 @@ def simulation_modifier(sim: Microsimulation): simulation_modifier(sim) - # Skip reforms for now - output_data = {} variable_blacklist = [ # TEMPORARY: we need to fix policyengine-uk to make these only take a long time with non-default parameters set to true. 
diff --git a/src/policyengine/models/report_element.py b/src/policyengine/models/report_element.py index 63055fd7..ac7fcfcb 100644 --- a/src/policyengine/models/report_element.py +++ b/src/policyengine/models/report_element.py @@ -29,8 +29,10 @@ class ReportElement(BaseModel): # Metadata report_id: str | None = None user_id: str | None = None + model_version_id: str | None = None position: int | None = None visible: bool | None = True custom_config: dict | None = None # Additional chart-specific config + report_element_metadata: dict | None = None # General metadata field for flexible data storage created_at: datetime | None = None updated_at: datetime | None = None From b79f4f400523a5ceb8fccd07bb944967f80be625 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Mon, 29 Sep 2025 18:24:11 +0100 Subject: [PATCH 03/35] Update --- build/lib/policyengine/__init__.py | 0 build/lib/policyengine/database/__init__.py | 68 ++++ build/lib/policyengine/database/aggregate.py | 101 ++++++ .../policyengine/database/aggregate_change.py | 122 +++++++ .../baseline_parameter_value_table.py | 112 +++++++ .../database/baseline_variable_table.py | 81 +++++ build/lib/policyengine/database/database.py | 301 ++++++++++++++++++ .../policyengine/database/dataset_table.py | 94 ++++++ .../policyengine/database/dynamic_table.py | 68 ++++ build/lib/policyengine/database/link.py | 8 + .../lib/policyengine/database/model_table.py | 60 ++++ .../database/model_version_table.py | 73 +++++ .../policyengine/database/parameter_table.py | 92 ++++++ .../database/parameter_value_table.py | 108 +++++++ .../lib/policyengine/database/policy_table.py | 136 ++++++++ .../database/report_element_table.py | 106 ++++++ .../lib/policyengine/database/report_table.py | 120 +++++++ .../policyengine/database/simulation_table.py | 225 +++++++++++++ .../lib/policyengine/database/table_mixin.py | 80 +++++ build/lib/policyengine/database/user_table.py | 57 ++++ .../database/versioned_dataset_table.py | 45 +++ build/lib/policyengine/models/__init__.py | 39 +++ build/lib/policyengine/models/aggregate.py | 132 ++++++++ .../policyengine/models/aggregate_change.py | 143 +++++++++ .../models/baseline_parameter_value.py | 16 + .../policyengine/models/baseline_variable.py | 12 + build/lib/policyengine/models/dataset.py | 18 ++ build/lib/policyengine/models/dynamic.py | 15 + build/lib/policyengine/models/model.py | 126 ++++++++ .../lib/policyengine/models/model_version.py | 14 + build/lib/policyengine/models/parameter.py | 14 + .../policyengine/models/parameter_value.py | 14 + build/lib/policyengine/models/policy.py | 17 + .../policyengine/models/policyengine_uk.py | 113 +++++++ .../policyengine/models/policyengine_us.py | 115 +++++++ build/lib/policyengine/models/report.py | 20 ++ .../lib/policyengine/models/report_element.py | 38 +++ build/lib/policyengine/models/simulation.py | 35 ++ build/lib/policyengine/models/user.py | 14 + .../policyengine/models/versioned_dataset.py | 12 + build/lib/policyengine/utils/charts.py | 286 +++++++++++++++++ build/lib/policyengine/utils/compress.py | 20 ++ build/lib/policyengine/utils/datasets.py | 71 +++++ docs/quickstart.ipynb | 36 +-- src/policyengine/database/__init__.py | 3 + src/policyengine/database/aggregate.py | 9 + src/policyengine/database/aggregate_change.py | 128 ++++++++ .../database/report_element_table.py | 2 +- src/policyengine/models/__init__.py | 3 + src/policyengine/models/aggregate.py | 5 +- src/policyengine/models/aggregate_change.py | 143 +++++++++ src/policyengine/models/report_element.py | 2 +- 
52 files changed, 3651 insertions(+), 21 deletions(-) create mode 100644 build/lib/policyengine/__init__.py create mode 100644 build/lib/policyengine/database/__init__.py create mode 100644 build/lib/policyengine/database/aggregate.py create mode 100644 build/lib/policyengine/database/aggregate_change.py create mode 100644 build/lib/policyengine/database/baseline_parameter_value_table.py create mode 100644 build/lib/policyengine/database/baseline_variable_table.py create mode 100644 build/lib/policyengine/database/database.py create mode 100644 build/lib/policyengine/database/dataset_table.py create mode 100644 build/lib/policyengine/database/dynamic_table.py create mode 100644 build/lib/policyengine/database/link.py create mode 100644 build/lib/policyengine/database/model_table.py create mode 100644 build/lib/policyengine/database/model_version_table.py create mode 100644 build/lib/policyengine/database/parameter_table.py create mode 100644 build/lib/policyengine/database/parameter_value_table.py create mode 100644 build/lib/policyengine/database/policy_table.py create mode 100644 build/lib/policyengine/database/report_element_table.py create mode 100644 build/lib/policyengine/database/report_table.py create mode 100644 build/lib/policyengine/database/simulation_table.py create mode 100644 build/lib/policyengine/database/table_mixin.py create mode 100644 build/lib/policyengine/database/user_table.py create mode 100644 build/lib/policyengine/database/versioned_dataset_table.py create mode 100644 build/lib/policyengine/models/__init__.py create mode 100644 build/lib/policyengine/models/aggregate.py create mode 100644 build/lib/policyengine/models/aggregate_change.py create mode 100644 build/lib/policyengine/models/baseline_parameter_value.py create mode 100644 build/lib/policyengine/models/baseline_variable.py create mode 100644 build/lib/policyengine/models/dataset.py create mode 100644 build/lib/policyengine/models/dynamic.py create mode 100644 build/lib/policyengine/models/model.py create mode 100644 build/lib/policyengine/models/model_version.py create mode 100644 build/lib/policyengine/models/parameter.py create mode 100644 build/lib/policyengine/models/parameter_value.py create mode 100644 build/lib/policyengine/models/policy.py create mode 100644 build/lib/policyengine/models/policyengine_uk.py create mode 100644 build/lib/policyengine/models/policyengine_us.py create mode 100644 build/lib/policyengine/models/report.py create mode 100644 build/lib/policyengine/models/report_element.py create mode 100644 build/lib/policyengine/models/simulation.py create mode 100644 build/lib/policyengine/models/user.py create mode 100644 build/lib/policyengine/models/versioned_dataset.py create mode 100644 build/lib/policyengine/utils/charts.py create mode 100644 build/lib/policyengine/utils/compress.py create mode 100644 build/lib/policyengine/utils/datasets.py create mode 100644 src/policyengine/database/aggregate_change.py create mode 100644 src/policyengine/models/aggregate_change.py diff --git a/build/lib/policyengine/__init__.py b/build/lib/policyengine/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/build/lib/policyengine/database/__init__.py b/build/lib/policyengine/database/__init__.py new file mode 100644 index 00000000..88e1a21b --- /dev/null +++ b/build/lib/policyengine/database/__init__.py @@ -0,0 +1,68 @@ +from .baseline_parameter_value_table import ( + BaselineParameterValueTable, + baseline_parameter_value_table_link, +) +from .baseline_variable_table import ( + 
BaselineVariableTable, + baseline_variable_table_link, +) +from .database import Database +from .dataset_table import DatasetTable, dataset_table_link +from .dynamic_table import DynamicTable, dynamic_table_link +from .link import TableLink + +# Import all table classes and links +from .model_table import ModelTable, model_table_link +from .model_version_table import ModelVersionTable, model_version_table_link +from .parameter_table import ParameterTable, parameter_table_link +from .parameter_value_table import ( + ParameterValueTable, + parameter_value_table_link, +) +from .policy_table import PolicyTable, policy_table_link +from .simulation_table import SimulationTable, simulation_table_link +from .versioned_dataset_table import ( + VersionedDatasetTable, + versioned_dataset_table_link, +) +from .report_table import ReportTable, report_table_link +from .report_element_table import ReportElementTable, report_element_table_link +from .aggregate import AggregateTable, aggregate_table_link +from .aggregate_change import AggregateChangeTable, aggregate_change_table_link + +__all__ = [ + "Database", + "TableLink", + # Tables + "ModelTable", + "ModelVersionTable", + "DatasetTable", + "VersionedDatasetTable", + "PolicyTable", + "DynamicTable", + "ParameterTable", + "ParameterValueTable", + "BaselineParameterValueTable", + "BaselineVariableTable", + "SimulationTable", + "ReportTable", + "ReportElementTable", + "AggregateTable", + "AggregateChangeTable", + # Links + "model_table_link", + "model_version_table_link", + "dataset_table_link", + "versioned_dataset_table_link", + "policy_table_link", + "dynamic_table_link", + "parameter_table_link", + "parameter_value_table_link", + "baseline_parameter_value_table_link", + "baseline_variable_table_link", + "simulation_table_link", + "report_table_link", + "report_element_table_link", + "aggregate_table_link", + "aggregate_change_table_link", +] diff --git a/build/lib/policyengine/database/aggregate.py b/build/lib/policyengine/database/aggregate.py new file mode 100644 index 00000000..44c8aacd --- /dev/null +++ b/build/lib/policyengine/database/aggregate.py @@ -0,0 +1,101 @@ +from typing import TYPE_CHECKING +from uuid import uuid4 + +from sqlmodel import Field, SQLModel + +from policyengine.database.link import TableLink +from policyengine.models.aggregate import Aggregate +from policyengine.models import Simulation + +if TYPE_CHECKING: + from .database import Database + + +class AggregateTable(SQLModel, table=True): + __tablename__ = "aggregates" + + id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True) + simulation_id: str = Field( + foreign_key="simulations.id", ondelete="CASCADE" + ) + entity: str + variable_name: str + year: int | None = None + filter_variable_name: str | None = None + filter_variable_value: str | None = None + filter_variable_leq: float | None = None + filter_variable_geq: float | None = None + aggregate_function: str + reportelement_id: str | None = None + value: float | None = None + + @classmethod + def convert_from_model(cls, model: Aggregate, database: "Database" = None) -> "AggregateTable": + """Convert an Aggregate instance to an AggregateTable instance. 
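+ The owning simulation is flattened to simulation_id only; it is deliberately not re-persisted here, to avoid the circular save noted in the body.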
+ + Args: + model: The Aggregate instance to convert + database: The database instance for persisting the simulation if needed + + Returns: + An AggregateTable instance + """ + # Don't try to save the simulation here - it's already being saved + # This prevents circular references + + return cls( + id=model.id, + simulation_id=model.simulation.id if model.simulation else None, + entity=model.entity, + variable_name=model.variable_name, + year=model.year, + filter_variable_name=model.filter_variable_name, + filter_variable_value=model.filter_variable_value, + filter_variable_leq=model.filter_variable_leq, + filter_variable_geq=model.filter_variable_geq, + aggregate_function=model.aggregate_function, + reportelement_id=model.reportelement_id, + value=model.value, + ) + + def convert_to_model(self, database: "Database" = None) -> Aggregate: + """Convert this AggregateTable instance to an Aggregate instance. + + Args: + database: The database instance for resolving the simulation foreign key + + Returns: + An Aggregate instance + """ + from .simulation_table import SimulationTable + from sqlmodel import select + + # Resolve the simulation foreign key + simulation = None + if database and self.simulation_id: + sim_table = database.session.exec( + select(SimulationTable).where(SimulationTable.id == self.simulation_id) + ).first() + if sim_table: + simulation = sim_table.convert_to_model(database) + + return Aggregate( + id=self.id, + simulation=simulation, + entity=self.entity, + variable_name=self.variable_name, + year=self.year, + filter_variable_name=self.filter_variable_name, + filter_variable_value=self.filter_variable_value, + filter_variable_leq=self.filter_variable_leq, + filter_variable_geq=self.filter_variable_geq, + aggregate_function=self.aggregate_function, + reportelement_id=self.reportelement_id, + value=self.value, + ) + + +aggregate_table_link = TableLink( + model_cls=Aggregate, + table_cls=AggregateTable, +) diff --git a/build/lib/policyengine/database/aggregate_change.py b/build/lib/policyengine/database/aggregate_change.py new file mode 100644 index 00000000..1011ffcc --- /dev/null +++ b/build/lib/policyengine/database/aggregate_change.py @@ -0,0 +1,122 @@ +from typing import TYPE_CHECKING +from uuid import uuid4 + +from sqlmodel import Field, SQLModel + +from policyengine.database.link import TableLink +from policyengine.models.aggregate_change import AggregateChange + +if TYPE_CHECKING: + from .database import Database + + +class AggregateChangeTable(SQLModel, table=True): + __tablename__ = "aggregate_changes" + + id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True) + baseline_simulation_id: str = Field( + foreign_key="simulations.id", ondelete="CASCADE" + ) + comparison_simulation_id: str = Field( + foreign_key="simulations.id", ondelete="CASCADE" + ) + entity: str + variable_name: str + year: int | None = None + filter_variable_name: str | None = None + filter_variable_value: str | None = None + filter_variable_leq: float | None = None + filter_variable_geq: float | None = None + aggregate_function: str + reportelement_id: str | None = None + + baseline_value: float | None = None + comparison_value: float | None = None + change: float | None = None + relative_change: float | None = None + + @classmethod + def convert_from_model(cls, model: AggregateChange, database: "Database" = None) -> "AggregateChangeTable": + """Convert an AggregateChange instance to an AggregateChangeTable instance. 
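+ Both simulations are flattened to their ids, and the computed baseline_value, comparison_value, change and relative_change are stored as plain columns.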
+ + Args: + model: The AggregateChange instance to convert + database: The database instance for persisting the simulations if needed + + Returns: + An AggregateChangeTable instance + """ + return cls( + id=model.id, + baseline_simulation_id=model.baseline_simulation.id if model.baseline_simulation else None, + comparison_simulation_id=model.comparison_simulation.id if model.comparison_simulation else None, + entity=model.entity, + variable_name=model.variable_name, + year=model.year, + filter_variable_name=model.filter_variable_name, + filter_variable_value=model.filter_variable_value, + filter_variable_leq=model.filter_variable_leq, + filter_variable_geq=model.filter_variable_geq, + aggregate_function=model.aggregate_function, + reportelement_id=model.reportelement_id, + baseline_value=model.baseline_value, + comparison_value=model.comparison_value, + change=model.change, + relative_change=model.relative_change, + ) + + def convert_to_model(self, database: "Database" = None) -> AggregateChange: + """Convert this AggregateChangeTable instance to an AggregateChange instance. + + Args: + database: The database instance for resolving simulation foreign keys + + Returns: + An AggregateChange instance + """ + from .simulation_table import SimulationTable + from sqlmodel import select + + # Resolve the simulation foreign keys + baseline_simulation = None + comparison_simulation = None + + if database: + if self.baseline_simulation_id: + sim_table = database.session.exec( + select(SimulationTable).where(SimulationTable.id == self.baseline_simulation_id) + ).first() + if sim_table: + baseline_simulation = sim_table.convert_to_model(database) + + if self.comparison_simulation_id: + sim_table = database.session.exec( + select(SimulationTable).where(SimulationTable.id == self.comparison_simulation_id) + ).first() + if sim_table: + comparison_simulation = sim_table.convert_to_model(database) + + return AggregateChange( + id=self.id, + baseline_simulation=baseline_simulation, + comparison_simulation=comparison_simulation, + entity=self.entity, + variable_name=self.variable_name, + year=self.year, + filter_variable_name=self.filter_variable_name, + filter_variable_value=self.filter_variable_value, + filter_variable_leq=self.filter_variable_leq, + filter_variable_geq=self.filter_variable_geq, + aggregate_function=self.aggregate_function, + reportelement_id=self.reportelement_id, + baseline_value=self.baseline_value, + comparison_value=self.comparison_value, + change=self.change, + relative_change=self.relative_change, + ) + + +aggregate_change_table_link = TableLink( + model_cls=AggregateChange, + table_cls=AggregateChangeTable, +) \ No newline at end of file diff --git a/build/lib/policyengine/database/baseline_parameter_value_table.py b/build/lib/policyengine/database/baseline_parameter_value_table.py new file mode 100644 index 00000000..6485223c --- /dev/null +++ b/build/lib/policyengine/database/baseline_parameter_value_table.py @@ -0,0 +1,112 @@ +from datetime import datetime +from typing import Any +from uuid import uuid4 + +from sqlmodel import JSON, Column, Field, SQLModel +from typing import TYPE_CHECKING + +from policyengine.models import ModelVersion, Parameter, BaselineParameterValue + +from .link import TableLink + +if TYPE_CHECKING: + from .database import Database + + +class BaselineParameterValueTable(SQLModel, table=True): + __tablename__ = "baseline_parameter_values" + __table_args__ = ({"extend_existing": True},) + + id: str = Field(default_factory=lambda: str(uuid4()), 
primary_key=True) + parameter_id: str = Field(nullable=False) # Part of composite foreign key + model_id: str = Field(nullable=False) # Part of composite foreign key + model_version_id: str = Field( + foreign_key="model_versions.id", ondelete="CASCADE" + ) + value: Any | None = Field( + default=None, sa_column=Column(JSON) + ) # JSON field for any type + start_date: datetime = Field(nullable=False) + end_date: datetime | None = Field(default=None) + + @classmethod + def convert_from_model(cls, model: BaselineParameterValue, database: "Database" = None) -> "BaselineParameterValueTable": + """Convert a BaselineParameterValue instance to a BaselineParameterValueTable instance.""" + import math + + # Ensure foreign objects are persisted if database is provided + if database: + if model.parameter: + database.set(model.parameter, commit=False) + if model.model_version: + database.set(model.model_version, commit=False) + + # Handle special float values + value = model.value + if isinstance(value, float): + if math.isinf(value): + value = "Infinity" if value > 0 else "-Infinity" + elif math.isnan(value): + value = "NaN" + + return cls( + id=model.id, + parameter_id=model.parameter.id if model.parameter else None, + model_id=model.parameter.model.id if model.parameter and model.parameter.model else None, + model_version_id=model.model_version.id if model.model_version else None, + value=value, + start_date=model.start_date, + end_date=model.end_date, + ) + + def convert_to_model(self, database: "Database" = None) -> BaselineParameterValue: + """Convert this BaselineParameterValueTable instance to a BaselineParameterValue instance.""" + from .parameter_table import ParameterTable + from .model_version_table import ModelVersionTable + from sqlmodel import select + + # Resolve foreign keys + parameter = None + model_version = None + + if database: + if self.parameter_id and self.model_id: + param_table = database.session.exec( + select(ParameterTable).where( + ParameterTable.id == self.parameter_id, + ParameterTable.model_id == self.model_id + ) + ).first() + if param_table: + parameter = param_table.convert_to_model(database) + + if self.model_version_id: + version_table = database.session.exec( + select(ModelVersionTable).where(ModelVersionTable.id == self.model_version_id) + ).first() + if version_table: + model_version = version_table.convert_to_model(database) + + # Handle special string values + value = self.value + if value == "Infinity": + value = float("inf") + elif value == "-Infinity": + value = float("-inf") + elif value == "NaN": + value = float("nan") + + return BaselineParameterValue( + id=self.id, + parameter=parameter, + model_version=model_version, + value=value, + start_date=self.start_date, + end_date=self.end_date, + ) + + +baseline_parameter_value_table_link = TableLink( + model_cls=BaselineParameterValue, + table_cls=BaselineParameterValueTable, +) diff --git a/build/lib/policyengine/database/baseline_variable_table.py b/build/lib/policyengine/database/baseline_variable_table.py new file mode 100644 index 00000000..e7773c80 --- /dev/null +++ b/build/lib/policyengine/database/baseline_variable_table.py @@ -0,0 +1,81 @@ +from sqlmodel import Field, SQLModel +from typing import TYPE_CHECKING + +from policyengine.models import ModelVersion, BaselineVariable + +from .link import TableLink + +if TYPE_CHECKING: + from .database import Database + + +class BaselineVariableTable(SQLModel, table=True): + __tablename__ = "baseline_variables" + __table_args__ = ({"extend_existing": True},) + 
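+ # Composite primary key: the variable name (id) scoped to its model (model_id).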
+ id: str = Field(primary_key=True) # Variable name + model_id: str = Field( + primary_key=True, foreign_key="models.id" + ) # Part of composite key + model_version_id: str = Field( + foreign_key="model_versions.id", ondelete="CASCADE" + ) + entity: str = Field(nullable=False) + label: str | None = Field(default=None) + description: str | None = Field(default=None) + data_type: str | None = Field(default=None) # Data type name + + @classmethod + def convert_from_model(cls, model: BaselineVariable, database: "Database" = None) -> "BaselineVariableTable": + """Convert a BaselineVariable instance to a BaselineVariableTable instance.""" + # Ensure foreign objects are persisted if database is provided + if database and model.model_version: + database.set(model.model_version, commit=False) + + return cls( + id=model.id, + model_id=model.model_version.model.id if model.model_version and model.model_version.model else None, + model_version_id=model.model_version.id if model.model_version else None, + entity=model.entity, + label=model.label, + description=model.description, + data_type=model.data_type.__name__ if model.data_type else None, + ) + + def convert_to_model(self, database: "Database" = None) -> BaselineVariable: + """Convert this BaselineVariableTable instance to a BaselineVariable instance.""" + from .model_version_table import ModelVersionTable + from sqlmodel import select + + # Resolve foreign keys + model_version = None + + if database and self.model_version_id: + version_table = database.session.exec( + select(ModelVersionTable).where(ModelVersionTable.id == self.model_version_id) + ).first() + if version_table: + model_version = version_table.convert_to_model(database) + + # Convert the stored data_type name back to the builtin type (avoids eval on DB contents) + import builtins + data_type = None + if self.data_type: + data_type = getattr(builtins, self.data_type, None) + + return BaselineVariable( + id=self.id, + model_version=model_version, + entity=self.entity, + label=self.label, + description=self.description, + data_type=data_type, + ) + + +baseline_variable_table_link = TableLink( + model_cls=BaselineVariable, + table_cls=BaselineVariableTable, +) diff --git a/build/lib/policyengine/database/database.py b/build/lib/policyengine/database/database.py new file mode 100644 index 00000000..2ae77e1c --- /dev/null +++ b/build/lib/policyengine/database/database.py @@ -0,0 +1,301 @@ +from typing import Any + +from sqlmodel import Session, SQLModel + +from .aggregate import aggregate_table_link +from .baseline_parameter_value_table import baseline_parameter_value_table_link +from .baseline_variable_table import baseline_variable_table_link +from .dataset_table import dataset_table_link +from .dynamic_table import dynamic_table_link +from .link import TableLink + +# Import all table links +from .model_table import model_table_link +from .model_version_table import model_version_table_link +from .parameter_table import parameter_table_link +from .parameter_value_table import parameter_value_table_link +from .policy_table import policy_table_link +from .report_element_table import report_element_table_link +from .report_table import report_table_link +from .simulation_table import simulation_table_link +from .user_table import user_table_link +from .versioned_dataset_table import versioned_dataset_table_link + + +class Database: + url: str + + _model_table_links: list[TableLink] = [] + + def __init__(self, url: str): + self.url = url + self.engine = self._create_engine() + self.session = Session(self.engine) + + for link in [
model_table_link, + model_version_table_link, + dataset_table_link, + versioned_dataset_table_link, + policy_table_link, + dynamic_table_link, + parameter_table_link, + parameter_value_table_link, + baseline_parameter_value_table_link, + baseline_variable_table_link, + simulation_table_link, + aggregate_table_link, + user_table_link, + report_table_link, + report_element_table_link, + ]: + self.register_table(link) + + def _create_engine(self): + from sqlmodel import create_engine + + return create_engine(self.url, echo=False) + + def create_tables(self): + """Create all database tables.""" + SQLModel.metadata.create_all(self.engine) + + def drop_tables(self): + """Drop all database tables.""" + SQLModel.metadata.drop_all(self.engine) + + def reset(self): + """Drop and recreate all tables.""" + self.drop_tables() + self.create_tables() + + def __enter__(self): + """Context manager entry - creates a session.""" + self.session = Session(self.engine) + return self.session + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit - closes the session.""" + if exc_type: + self.session.rollback() + else: + self.session.commit() + self.session.close() + + def register_table(self, link: TableLink): + self._model_table_links.append(link) + # Create the table if not exists + link.table_cls.metadata.create_all(self.engine) + + def get(self, model_cls: type, **kwargs): + """Get a model instance from the database by its attributes.""" + from sqlmodel import select + + # Find the table class for this model + table_link = next( + ( + link + for link in self._model_table_links + if link.model_cls == model_cls + ), + None, + ) + + if table_link is None: + return None + + # Query the database + statement = select(table_link.table_cls).filter_by(**kwargs) + result = self.session.exec(statement).first() + + if result is None: + return None + + # Use the table's convert_to_model method + return result.convert_to_model(self) + + def set(self, object: Any, commit: bool = True): + """Save or update a model instance in the database.""" + from sqlmodel import select + from sqlalchemy.inspection import inspect + + # Find the table class for this model + table_link = next( + ( + link + for link in self._model_table_links + if link.model_cls is type(object) + ), + None, + ) + + if table_link is None: + return + + # Convert model to table instance + table_obj = table_link.table_cls.convert_from_model(object, self) + + # Get primary key columns + mapper = inspect(table_link.table_cls) + pk_cols = [col.name for col in mapper.primary_key] + + # Build query to check if exists + query = select(table_link.table_cls) + for pk_col in pk_cols: + query = query.where( + getattr(table_link.table_cls, pk_col) == getattr(table_obj, pk_col) + ) + + existing = self.session.exec(query).first() + + if existing: + # Update existing record + for key, value in table_obj.model_dump().items(): + setattr(existing, key, value) + self.session.add(existing) + else: + self.session.add(table_obj) + + if commit: + self.session.commit() + + def register_model_version(self, model_version): + """Register a model version with its model and seed objects. 
+ This replaces all existing parameters, baseline parameter values, + and baseline variables for this model version.""" + # Add or update the model directly to avoid conflicts + from policyengine.utils.compress import compress_data + + from .baseline_parameter_value_table import BaselineParameterValueTable + from .baseline_variable_table import BaselineVariableTable + from .model_table import ModelTable + from .model_version_table import ModelVersionTable + from .parameter_table import ParameterTable + + existing_model = ( + self.session.query(ModelTable) + .filter(ModelTable.id == model_version.model.id) + .first() + ) + if not existing_model: + model_table = ModelTable( + id=model_version.model.id, + name=model_version.model.name, + description=model_version.model.description, + simulation_function=compress_data( + model_version.model.simulation_function + ), + ) + self.session.add(model_table) + self.session.flush() + + # Add or update the model version + existing_version = ( + self.session.query(ModelVersionTable) + .filter(ModelVersionTable.id == model_version.id) + .first() + ) + if not existing_version: + version_table = ModelVersionTable( + id=model_version.id, + model_id=model_version.model.id, + version=model_version.version, + description=model_version.description, + created_at=model_version.created_at, + ) + self.session.add(version_table) + self.session.flush() + + # Get seed objects from the model + seed_objects = model_version.model.create_seed_objects(model_version) + + # Delete ALL existing seed data for this model (not just this version) + # This ensures we start fresh with the new version's data + # Order matters due to foreign key constraints + + # First delete baseline parameter values (they reference parameters) + self.session.query(BaselineParameterValueTable).filter( + BaselineParameterValueTable.model_id == model_version.model.id + ).delete() + + # Then delete baseline variables for this model + self.session.query(BaselineVariableTable).filter( + BaselineVariableTable.model_id == model_version.model.id + ).delete() + + # Finally delete all parameters for this model + self.session.query(ParameterTable).filter( + ParameterTable.model_id == model_version.model.id + ).delete() + + self.session.commit() + + # Add all parameters first + for parameter in seed_objects.parameters: + # We need to add directly to session to avoid the autoflush issue + from .parameter_table import ParameterTable + + param_table = ParameterTable( + id=parameter.id, + model_id=parameter.model.id, # Now required as part of composite key + description=parameter.description, + data_type=parameter.data_type.__name__ + if parameter.data_type + else None, + label=parameter.label, + unit=parameter.unit, + ) + self.session.add(param_table) + + # Flush parameters to database so they exist for foreign key constraints + self.session.flush() + + # Add all baseline parameter values + for baseline_param_value in seed_objects.baseline_parameter_values: + import math + from uuid import uuid4 + + from .baseline_parameter_value_table import ( + BaselineParameterValueTable, + ) + + # Handle special float values that JSON doesn't support + value = baseline_param_value.value + if isinstance(value, float): + if math.isinf(value): + value = "Infinity" if value > 0 else "-Infinity" + elif math.isnan(value): + value = "NaN" + + bpv_table = BaselineParameterValueTable( + id=str(uuid4()), + parameter_id=baseline_param_value.parameter.id, + model_id=baseline_param_value.parameter.model.id, # Add model_id + 
model_version_id=baseline_param_value.model_version.id, + value=value, + start_date=baseline_param_value.start_date, + end_date=baseline_param_value.end_date, + ) + self.session.add(bpv_table) + + # Add all baseline variables + for baseline_variable in seed_objects.baseline_variables: + from .baseline_variable_table import BaselineVariableTable + + bv_table = BaselineVariableTable( + id=baseline_variable.id, + model_id=baseline_variable.model_version.model.id, # Add model_id + model_version_id=baseline_variable.model_version.id, + entity=baseline_variable.entity, + label=baseline_variable.label, + description=baseline_variable.description, + # The data_type column stores the type name as a string, not compressed bytes + data_type=baseline_variable.data_type.__name__ + if baseline_variable.data_type + else None, + ) + self.session.add(bv_table) + + # Commit everything at once + self.session.commit() diff --git a/build/lib/policyengine/database/dataset_table.py b/build/lib/policyengine/database/dataset_table.py new file mode 100644 index 00000000..cf22cda8 --- /dev/null +++ b/build/lib/policyengine/database/dataset_table.py @@ -0,0 +1,94 @@ +from typing import TYPE_CHECKING +from uuid import uuid4 + +from sqlmodel import Field, SQLModel + +from policyengine.models import Dataset, Model, VersionedDataset +from policyengine.utils.compress import compress_data, decompress_data + +from .link import TableLink + +if TYPE_CHECKING: + from .database import Database + + +class DatasetTable(SQLModel, table=True): + __tablename__ = "datasets" + + id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True) + name: str = Field(nullable=False) + description: str | None = Field(default=None) + version: str | None = Field(default=None) + versioned_dataset_id: str | None = Field( + default=None, foreign_key="versioned_datasets.id", ondelete="SET NULL" + ) + year: int | None = Field(default=None) + data: bytes | None = Field(default=None) + model_id: str | None = Field( + default=None, foreign_key="models.id", ondelete="SET NULL" + ) + + @classmethod + def convert_from_model(cls, model: Dataset, database: "Database" = None) -> "DatasetTable": + """Convert a Dataset instance to a DatasetTable instance. + + Args: + model: The Dataset instance to convert + database: The database instance for persisting foreign objects if needed + + Returns: + A DatasetTable instance + """ + # Ensure foreign objects are persisted if database is provided + if database: + if model.versioned_dataset: + database.set(model.versioned_dataset, commit=False) + if model.model: + database.set(model.model, commit=False) + + return cls( + id=model.id, + name=model.name, + description=model.description, + version=model.version, + versioned_dataset_id=model.versioned_dataset.id if model.versioned_dataset else None, + year=model.year, + data=compress_data(model.data) if model.data else None, + model_id=model.model.id if model.model else None, + ) + + def convert_to_model(self, database: "Database" = None) -> Dataset: + """Convert this DatasetTable instance to a Dataset instance.
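+ Foreign keys are resolved via database.get, and the stored blob is decompressed back into the dataset's data payload.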
+ + Args: + database: The database instance for resolving foreign keys + + Returns: + A Dataset instance + """ + # Resolve foreign keys + versioned_dataset = None + model = None + + if database: + if self.versioned_dataset_id: + versioned_dataset = database.get(VersionedDataset, id=self.versioned_dataset_id) + if self.model_id: + model = database.get(Model, id=self.model_id) + + return Dataset( + id=self.id, + name=self.name, + description=self.description, + version=self.version, + versioned_dataset=versioned_dataset, + year=self.year, + data=decompress_data(self.data) if self.data else None, + model=model, + ) + + +dataset_table_link = TableLink( + model_cls=Dataset, + table_cls=DatasetTable, +) diff --git a/build/lib/policyengine/database/dynamic_table.py b/build/lib/policyengine/database/dynamic_table.py new file mode 100644 index 00000000..086e6bd9 --- /dev/null +++ b/build/lib/policyengine/database/dynamic_table.py @@ -0,0 +1,68 @@ +from datetime import datetime +from typing import TYPE_CHECKING +from uuid import uuid4 + +from sqlmodel import Field, SQLModel + +from policyengine.models import Dynamic +from policyengine.utils.compress import compress_data, decompress_data + +from .link import TableLink + +if TYPE_CHECKING: + from .database import Database + + +class DynamicTable(SQLModel, table=True): + __tablename__ = "dynamics" + + id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True) + name: str = Field(nullable=False) + description: str | None = Field(default=None) + simulation_modifier: bytes | None = Field(default=None) + created_at: datetime = Field(default_factory=datetime.now) + updated_at: datetime = Field(default_factory=datetime.now) + + @classmethod + def convert_from_model(cls, model: Dynamic, database: "Database" = None) -> "DynamicTable": + """Convert a Dynamic instance to a DynamicTable instance. + + Args: + model: The Dynamic instance to convert + database: The database instance (not used for this table) + + Returns: + A DynamicTable instance + """ + return cls( + id=model.id, + name=model.name, + description=model.description, + simulation_modifier=compress_data(model.simulation_modifier) if model.simulation_modifier else None, + created_at=model.created_at, + updated_at=model.updated_at, + ) + + def convert_to_model(self, database: "Database" = None) -> Dynamic: + """Convert this DynamicTable instance to a Dynamic instance. 
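+ The compressed simulation_modifier blob, if present, is decompressed back into the original modifier object.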
+ + Args: + database: The database instance (not used for this table) + + Returns: + A Dynamic instance + """ + return Dynamic( + id=self.id, + name=self.name, + description=self.description, + simulation_modifier=decompress_data(self.simulation_modifier) if self.simulation_modifier else None, + created_at=self.created_at, + updated_at=self.updated_at, + ) + + +dynamic_table_link = TableLink( + model_cls=Dynamic, + table_cls=DynamicTable, +) diff --git a/build/lib/policyengine/database/link.py b/build/lib/policyengine/database/link.py new file mode 100644 index 00000000..2bb1a041 --- /dev/null +++ b/build/lib/policyengine/database/link.py @@ -0,0 +1,8 @@ +from pydantic import BaseModel +from sqlmodel import SQLModel + + +class TableLink(BaseModel): + """Simple registry mapping model classes to table classes.""" + model_cls: type[BaseModel] + table_cls: type[SQLModel] diff --git a/build/lib/policyengine/database/model_table.py b/build/lib/policyengine/database/model_table.py new file mode 100644 index 00000000..220238c8 --- /dev/null +++ b/build/lib/policyengine/database/model_table.py @@ -0,0 +1,60 @@ +from typing import TYPE_CHECKING + +from sqlmodel import Field, SQLModel + +from policyengine.models import Model +from policyengine.utils.compress import compress_data, decompress_data + +from .link import TableLink + +if TYPE_CHECKING: + from .database import Database + + +class ModelTable(SQLModel, table=True, extend_existing=True): + __tablename__ = "models" + + id: str = Field(primary_key=True) + name: str = Field(nullable=False) + description: str | None = Field(default=None) + simulation_function: bytes + + @classmethod + def convert_from_model(cls, model: Model, database: "Database" = None) -> "ModelTable": + """Convert a Model instance to a ModelTable instance. + + Args: + model: The Model instance to convert + database: The database instance (not used for this table) + + Returns: + A ModelTable instance + """ + return cls( + id=model.id, + name=model.name, + description=model.description, + simulation_function=compress_data(model.simulation_function), + ) + + def convert_to_model(self, database: "Database" = None) -> Model: + """Convert this ModelTable instance to a Model instance. 
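+ The stored simulation_function bytes are decompressed back into the model's simulation function.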
+ + Args: + database: The database instance (not used for this table) + + Returns: + A Model instance + """ + return Model( + id=self.id, + name=self.name, + description=self.description, + simulation_function=decompress_data(self.simulation_function), + ) + + +model_table_link = TableLink( + model_cls=Model, + table_cls=ModelTable, +) diff --git a/build/lib/policyengine/database/model_version_table.py b/build/lib/policyengine/database/model_version_table.py new file mode 100644 index 00000000..86d19fed --- /dev/null +++ b/build/lib/policyengine/database/model_version_table.py @@ -0,0 +1,73 @@ +from datetime import datetime +from typing import TYPE_CHECKING +from uuid import uuid4 + +from sqlmodel import Field, SQLModel + +from policyengine.models import Model, ModelVersion + +from .link import TableLink + +if TYPE_CHECKING: + from .database import Database + + +class ModelVersionTable(SQLModel, table=True): + __tablename__ = "model_versions" + + id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True) + model_id: str = Field(foreign_key="models.id", ondelete="CASCADE") + version: str = Field(nullable=False) + description: str | None = Field(default=None) + created_at: datetime = Field(default_factory=datetime.now) + + @classmethod + def convert_from_model(cls, model: ModelVersion, database: "Database" = None) -> "ModelVersionTable": + """Convert a ModelVersion instance to a ModelVersionTable instance. + + Args: + model: The ModelVersion instance to convert + database: The database instance for persisting the model if needed + + Returns: + A ModelVersionTable instance + """ + # Ensure the Model is persisted if database is provided + if database and model.model: + database.set(model.model, commit=False) + + return cls( + id=model.id, + model_id=model.model.id if model.model else None, + version=model.version, + description=model.description, + created_at=model.created_at, + ) + + def convert_to_model(self, database: "Database" = None) -> ModelVersion: + """Convert this ModelVersionTable instance to a ModelVersion instance. 
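+
+        The ``model`` field is rehydrated with ``database.get(Model, id=...)``,
+        so without a database instance it is returned as ``None``.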
+ + Args: + database: The database instance for resolving the model foreign key + + Returns: + A ModelVersion instance + """ + # Resolve the model foreign key + model = None + if database and self.model_id: + model = database.get(Model, id=self.model_id) + + return ModelVersion( + id=self.id, + model=model, + version=self.version, + description=self.description, + created_at=self.created_at, + ) + + +model_version_table_link = TableLink( + model_cls=ModelVersion, + table_cls=ModelVersionTable, +) diff --git a/build/lib/policyengine/database/parameter_table.py b/build/lib/policyengine/database/parameter_table.py new file mode 100644 index 00000000..aef88e5a --- /dev/null +++ b/build/lib/policyengine/database/parameter_table.py @@ -0,0 +1,92 @@ +from typing import TYPE_CHECKING + +from sqlmodel import Field, SQLModel + +from policyengine.models import Model, Parameter + +from .link import TableLink + +if TYPE_CHECKING: + from .database import Database + + +class ParameterTable(SQLModel, table=True): + __tablename__ = "parameters" + __table_args__ = ({"extend_existing": True},) + + id: str = Field(primary_key=True) # Parameter name + model_id: str = Field( + primary_key=True, foreign_key="models.id" + ) # Part of composite key + description: str | None = Field(default=None) + data_type: str | None = Field(nullable=True) # Data type name + label: str | None = Field(default=None) + unit: str | None = Field(default=None) + + @classmethod + def convert_from_model(cls, model: Parameter, database: "Database" = None) -> "ParameterTable": + """Convert a Parameter instance to a ParameterTable instance. + + Args: + model: The Parameter instance to convert + database: The database instance for persisting the model if needed + + Returns: + A ParameterTable instance + """ + # Ensure the Model is persisted if database is provided + if database and model.model: + database.set(model.model, commit=False) + + return cls( + id=model.id, + model_id=model.model.id if model.model else None, + description=model.description, + data_type=model.data_type.__name__ if model.data_type else None, + label=model.label, + unit=model.unit, + ) + + def convert_to_model(self, database: "Database" = None) -> Parameter: + """Convert this ParameterTable instance to a Parameter instance. 
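+
+        Rows are keyed by (parameter name, model id); the stored ``data_type``
+        name is mapped back to a builtin type where possible.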
+
+        Args:
+            database: The database instance for resolving the model foreign key
+
+        Returns:
+            A Parameter instance
+        """
+        import builtins
+
+        from sqlmodel import select
+
+        from .model_table import ModelTable
+
+        # Resolve the model foreign key
+        model = None
+        if database and self.model_id:
+            model_table = database.session.exec(
+                select(ModelTable).where(ModelTable.id == self.model_id)
+            ).first()
+            if model_table:
+                model = model_table.convert_to_model(database)
+
+        # Convert the stored data type name back to a type. Only builtin type
+        # names are recognised; anything else maps to None (eval would be
+        # unsafe on database-sourced strings).
+        data_type = None
+        if self.data_type:
+            candidate = getattr(builtins, self.data_type, None)
+            if isinstance(candidate, type):
+                data_type = candidate
+
+        return Parameter(
+            id=self.id,
+            description=self.description,
+            data_type=data_type,
+            model=model,
+            label=self.label,
+            unit=self.unit,
+        )
+
+
+parameter_table_link = TableLink(
+    model_cls=Parameter,
+    table_cls=ParameterTable,
+)
diff --git a/build/lib/policyengine/database/parameter_value_table.py b/build/lib/policyengine/database/parameter_value_table.py
new file mode 100644
index 00000000..7bd02d0a
--- /dev/null
+++ b/build/lib/policyengine/database/parameter_value_table.py
@@ -0,0 +1,108 @@
+from datetime import datetime
+from typing import TYPE_CHECKING, Any
+from uuid import uuid4
+
+from sqlmodel import JSON, Column, Field, SQLModel
+
+from policyengine.models import Parameter, ParameterValue
+
+from .link import TableLink
+
+if TYPE_CHECKING:
+    from .database import Database
+
+
+class ParameterValueTable(SQLModel, table=True):
+    __tablename__ = "parameter_values"
+    __table_args__ = ({"extend_existing": True},)
+
+    id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True)
+    parameter_id: str = Field(nullable=False)  # Part of composite foreign key
+    model_id: str = Field(nullable=False)  # Part of composite foreign key
+    policy_id: str | None = Field(default=None, foreign_key="policies.id", ondelete="CASCADE")  # Link to policy
+    value: Any | None = Field(
+        default=None, sa_column=Column(JSON)
+    )  # JSON field for any type
+    start_date: datetime = Field(nullable=False)
+    end_date: datetime | None = Field(default=None)
+
+    @classmethod
+    def convert_from_model(cls, model: ParameterValue, database: "Database" = None) -> "ParameterValueTable":
+        """Convert a ParameterValue instance to a ParameterValueTable instance.
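+
+        Because the ``value`` column is JSON, non-representable floats are
+        stored as the sentinel strings "Infinity", "-Infinity" and "NaN" and
+        restored on read.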
+ + Args: + database: The database instance for resolving the parameter foreign key + + Returns: + A ParameterValue instance + """ + from .parameter_table import ParameterTable + from sqlmodel import select + + # Resolve the parameter foreign key + parameter = None + if database and self.parameter_id and self.model_id: + param_table = database.session.exec( + select(ParameterTable).where( + ParameterTable.id == self.parameter_id, + ParameterTable.model_id == self.model_id + ) + ).first() + if param_table: + parameter = param_table.convert_to_model(database) + + # Handle special string values + value = self.value + if value == "Infinity": + value = float("inf") + elif value == "-Infinity": + value = float("-inf") + elif value == "NaN": + value = float("nan") + + return ParameterValue( + id=self.id, + parameter=parameter, + value=value, + start_date=self.start_date, + end_date=self.end_date, + ) + + +parameter_value_table_link = TableLink( + model_cls=ParameterValue, + table_cls=ParameterValueTable, +) diff --git a/build/lib/policyengine/database/policy_table.py b/build/lib/policyengine/database/policy_table.py new file mode 100644 index 00000000..0ae381e4 --- /dev/null +++ b/build/lib/policyengine/database/policy_table.py @@ -0,0 +1,136 @@ +from datetime import datetime +from typing import TYPE_CHECKING +from uuid import uuid4 + +from sqlmodel import Field, SQLModel + +from policyengine.models import Policy +from policyengine.utils.compress import compress_data, decompress_data + +from .link import TableLink + +if TYPE_CHECKING: + from .database import Database + + +class PolicyTable(SQLModel, table=True): + __tablename__ = "policies" + + id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True) + name: str = Field(nullable=False) + description: str | None = Field(default=None) + simulation_modifier: bytes | None = Field(default=None) + created_at: datetime = Field(default_factory=datetime.now) + updated_at: datetime = Field(default_factory=datetime.now) + + @classmethod + def convert_from_model(cls, model: Policy, database: "Database" = None) -> "PolicyTable": + """Convert a Policy instance to a PolicyTable instance. 
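+
+        When a database is provided, nested ``parameter_values`` are synced:
+        rows no longer present on the model are deleted, existing rows are
+        updated in place, and new rows are inserted and linked via ``policy_id``.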
+ + Args: + model: The Policy instance to convert + database: The database instance for persisting nested objects + + Returns: + A PolicyTable instance + """ + policy_table = cls( + id=model.id, + name=model.name, + description=model.description, + simulation_modifier=compress_data(model.simulation_modifier) if model.simulation_modifier else None, + created_at=model.created_at, + updated_at=model.updated_at, + ) + + # Handle nested parameter values if database is provided + if database and model.parameter_values: + from .parameter_value_table import ParameterValueTable + from sqlmodel import select + + # First ensure the policy table is saved to the database + # This is necessary so the foreign key constraint is satisfied + # Check if it already exists + existing_policy = database.session.exec( + select(PolicyTable).where(PolicyTable.id == model.id) + ).first() + + if not existing_policy: + database.session.add(policy_table) + database.session.flush() + + # Track which parameter value IDs we want to keep + desired_pv_ids = {pv.id for pv in model.parameter_values} + + # Delete only parameter values linked to this policy that are NOT in the new list + existing_pvs = database.session.exec( + select(ParameterValueTable).where(ParameterValueTable.policy_id == model.id) + ).all() + for pv in existing_pvs: + if pv.id not in desired_pv_ids: + database.session.delete(pv) + + # Now save/update the parameter values + for param_value in model.parameter_values: + # Check if this parameter value already exists in the database + existing_pv = database.session.exec( + select(ParameterValueTable).where(ParameterValueTable.id == param_value.id) + ).first() + + if existing_pv: + # Update existing parameter value + pv_table = ParameterValueTable.convert_from_model(param_value, database) + existing_pv.parameter_id = pv_table.parameter_id + existing_pv.model_id = pv_table.model_id + existing_pv.policy_id = model.id + existing_pv.value = pv_table.value + existing_pv.start_date = pv_table.start_date + existing_pv.end_date = pv_table.end_date + else: + # Create new parameter value + pv_table = ParameterValueTable.convert_from_model(param_value, database) + pv_table.policy_id = model.id # Link to this policy + database.session.add(pv_table) + database.session.flush() + + return policy_table + + def convert_to_model(self, database: "Database" = None) -> Policy: + """Convert this PolicyTable instance to a Policy instance. 
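+
+        Nested parameter values are loaded by ``policy_id`` when a database is
+        provided; otherwise ``parameter_values`` is returned empty.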
+ + Args: + database: The database instance for loading nested objects + + Returns: + A Policy instance + """ + # Load nested parameter values if database is provided + parameter_values = [] + if database: + from .parameter_value_table import ParameterValueTable + from sqlmodel import select + + # Query for all parameter values linked to this policy + pv_tables = database.session.exec( + select(ParameterValueTable).where(ParameterValueTable.policy_id == self.id) + ).all() + + # Convert each one to a model + for pv_table in pv_tables: + parameter_values.append(pv_table.convert_to_model(database)) + + return Policy( + id=self.id, + name=self.name, + description=self.description, + parameter_values=parameter_values, + simulation_modifier=decompress_data(self.simulation_modifier) if self.simulation_modifier else None, + created_at=self.created_at, + updated_at=self.updated_at, + ) + + +policy_table_link = TableLink( + model_cls=Policy, + table_cls=PolicyTable, +) diff --git a/build/lib/policyengine/database/report_element_table.py b/build/lib/policyengine/database/report_element_table.py new file mode 100644 index 00000000..cc69e83e --- /dev/null +++ b/build/lib/policyengine/database/report_element_table.py @@ -0,0 +1,106 @@ +import uuid +from datetime import datetime + +from sqlmodel import Field, SQLModel, Column, JSON +from typing import TYPE_CHECKING + +from policyengine.models.report_element import ReportElement + +from .link import TableLink + +if TYPE_CHECKING: + from .database import Database + + +class ReportElementTable(SQLModel, table=True, extend_existing=True): + __tablename__ = "report_elements" + + id: str = Field( + primary_key=True, default_factory=lambda: str(uuid.uuid4()) + ) + label: str = Field(nullable=False) + type: str = Field(nullable=False) # "chart" or "markdown" + + # Data source + data_table: str | None = Field(default=None) # "aggregates" or "aggregate_changes" + + # Chart configuration + chart_type: str | None = Field( + default=None + ) # "bar", "line", "scatter", "area", "pie" + x_axis_variable: str | None = Field(default=None) + y_axis_variable: str | None = Field(default=None) + group_by: str | None = Field(default=None) + color_by: str | None = Field(default=None) + size_by: str | None = Field(default=None) + + # Markdown specific + markdown_content: str | None = Field(default=None) + + # Metadata + report_id: str | None = Field(default=None, foreign_key="reports.id") + user_id: str | None = Field(default=None, foreign_key="users.id") + model_version_id: str | None = Field(default=None, foreign_key="model_versions.id") + position: int | None = Field(default=None) + visible: bool | None = Field(default=True) + custom_config: dict | None = Field(default=None, sa_column=Column(JSON)) + report_element_metadata: dict | None = Field(default=None, sa_column=Column(JSON)) + created_at: datetime = Field(default_factory=datetime.utcnow) + updated_at: datetime = Field(default_factory=datetime.utcnow) + + @classmethod + def convert_from_model(cls, model: ReportElement, database: "Database" = None) -> "ReportElementTable": + """Convert a ReportElement instance to a ReportElementTable instance.""" + return cls( + id=model.id, + label=model.label, + type=model.type, + data_table=model.data_table, + chart_type=model.chart_type, + x_axis_variable=model.x_axis_variable, + y_axis_variable=model.y_axis_variable, + group_by=model.group_by, + color_by=model.color_by, + size_by=model.size_by, + markdown_content=model.markdown_content, + report_id=model.report_id, + 
user_id=model.user_id, + model_version_id=model.model_version_id, + position=model.position, + visible=model.visible, + custom_config=model.custom_config, + report_element_metadata=model.report_element_metadata, + created_at=model.created_at, + updated_at=model.updated_at, + ) + + def convert_to_model(self, database: "Database" = None) -> ReportElement: + """Convert this ReportElementTable instance to a ReportElement instance.""" + return ReportElement( + id=self.id, + label=self.label, + type=self.type, + data_table=self.data_table, + chart_type=self.chart_type, + x_axis_variable=self.x_axis_variable, + y_axis_variable=self.y_axis_variable, + group_by=self.group_by, + color_by=self.color_by, + size_by=self.size_by, + markdown_content=self.markdown_content, + report_id=self.report_id, + user_id=self.user_id, + model_version_id=self.model_version_id, + position=self.position, + visible=self.visible, + custom_config=self.custom_config, + report_element_metadata=self.report_element_metadata, + created_at=self.created_at, + updated_at=self.updated_at, + ) + + +report_element_table_link = TableLink( + model_cls=ReportElement, + table_cls=ReportElementTable, +) diff --git a/build/lib/policyengine/database/report_table.py b/build/lib/policyengine/database/report_table.py new file mode 100644 index 00000000..79c11cf0 --- /dev/null +++ b/build/lib/policyengine/database/report_table.py @@ -0,0 +1,120 @@ +import uuid +from datetime import datetime + +from sqlmodel import Field, SQLModel +from typing import TYPE_CHECKING + +from policyengine.models.report import Report + +from .link import TableLink + +if TYPE_CHECKING: + from .database import Database + + +class ReportTable(SQLModel, table=True, extend_existing=True): + __tablename__ = "reports" + + id: str = Field( + primary_key=True, default_factory=lambda: str(uuid.uuid4()) + ) + label: str = Field(nullable=False) + created_at: datetime = Field(default_factory=datetime.utcnow) + + @classmethod + def convert_from_model(cls, model: Report, database: "Database" = None) -> "ReportTable": + """Convert a Report instance to a ReportTable instance.""" + report_table = cls( + id=model.id, + label=model.label, + created_at=model.created_at, + ) + + # Handle nested report elements if database is provided + if database and model.elements: + from .report_element_table import ReportElementTable + from sqlmodel import select + + # First ensure the report table is saved to the database + # This is necessary so the foreign key constraint is satisfied + # Check if it already exists + existing_report = database.session.exec( + select(ReportTable).where(ReportTable.id == model.id) + ).first() + + if not existing_report: + database.session.add(report_table) + database.session.flush() + + # Track which element IDs we want to keep + desired_elem_ids = {elem.id for elem in model.elements} + + # Delete only elements linked to this report that are NOT in the new list + existing_elems = database.session.exec( + select(ReportElementTable).where(ReportElementTable.report_id == model.id) + ).all() + for elem in existing_elems: + if elem.id not in desired_elem_ids: + database.session.delete(elem) + + # Now save/update the elements + for i, element in enumerate(model.elements): + # Check if this element already exists in the database + existing_elem = database.session.exec( + select(ReportElementTable).where(ReportElementTable.id == element.id) + ).first() + + if existing_elem: + # Update existing element + elem_table = ReportElementTable.convert_from_model(element, database) + 
existing_elem.report_id = model.id
+                    existing_elem.position = i
+                    existing_elem.label = elem_table.label
+                    existing_elem.type = elem_table.type
+                    existing_elem.data_table = elem_table.data_table
+                    existing_elem.markdown_content = elem_table.markdown_content
+                    existing_elem.chart_type = elem_table.chart_type
+                    existing_elem.x_axis_variable = elem_table.x_axis_variable
+                    existing_elem.y_axis_variable = elem_table.y_axis_variable
+                    # Copy the remaining chart fields that exist on
+                    # ReportElementTable (it has no simulation id columns)
+                    existing_elem.group_by = elem_table.group_by
+                    existing_elem.color_by = elem_table.color_by
+                    existing_elem.size_by = elem_table.size_by
+                    existing_elem.custom_config = elem_table.custom_config
+                else:
+                    # Create new element
+                    elem_table = ReportElementTable.convert_from_model(element, database)
+                    elem_table.report_id = model.id  # Link to this report
+                    elem_table.position = i  # Maintain order
+                    database.session.add(elem_table)
+            database.session.flush()
+
+        return report_table
+
+    def convert_to_model(self, database: "Database" = None) -> Report:
+        """Convert this ReportTable instance to a Report instance."""
+        # Load nested report elements if database is provided
+        elements = []
+        if database:
+            from .report_element_table import ReportElementTable
+            from sqlmodel import select
+
+            # Query for all elements linked to this report, ordered by position
+            elem_tables = database.session.exec(
+                select(ReportElementTable)
+                .where(ReportElementTable.report_id == self.id)
+                .order_by(ReportElementTable.position)
+            ).all()
+
+            # Convert each one to a model
+            for elem_table in elem_tables:
+                elements.append(elem_table.convert_to_model(database))
+
+        return Report(
+            id=self.id,
+            label=self.label,
+            created_at=self.created_at,
+            elements=elements,
+        )
+
+
+report_table_link = TableLink(
+    model_cls=Report,
+    table_cls=ReportTable,
+)
diff --git a/build/lib/policyengine/database/simulation_table.py b/build/lib/policyengine/database/simulation_table.py
new file mode 100644
index 00000000..de45a419
--- /dev/null
+++ b/build/lib/policyengine/database/simulation_table.py
@@ -0,0 +1,225 @@
+from datetime import datetime
+from typing import TYPE_CHECKING
+from uuid import uuid4
+
+from sqlmodel import Field, SQLModel
+
+from policyengine.models import Dataset, Dynamic, Model, ModelVersion, Policy, Simulation
+from policyengine.utils.compress import compress_data, decompress_data
+
+from .link import TableLink
+
+if TYPE_CHECKING:
+    from .database import Database
+
+
+class SimulationTable(SQLModel, table=True):
+    __tablename__ = "simulations"
+
+    id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True)
+    created_at: datetime = Field(default_factory=datetime.now)
+    updated_at: datetime = Field(default_factory=datetime.now)
+
+    policy_id: str | None = Field(
+        default=None, foreign_key="policies.id", ondelete="SET NULL"
+    )
+    dynamic_id: str | None = Field(
+        default=None, foreign_key="dynamics.id", ondelete="SET NULL"
+    )
+    dataset_id: str = Field(foreign_key="datasets.id", ondelete="CASCADE")
+    model_id: str = Field(foreign_key="models.id", ondelete="CASCADE")
+    model_version_id: str | None = Field(
+        default=None, foreign_key="model_versions.id", ondelete="SET NULL"
+    )
+
+    result: bytes | None = Field(default=None)
+
+    @classmethod
+    def convert_from_model(cls, model: Simulation, database: "Database" = None) -> "SimulationTable":
+        """Convert a Simulation instance to a SimulationTable instance.
+ + Args: + model: The Simulation instance to convert + database: The database instance for persisting foreign objects if needed + + Returns: + A SimulationTable instance + """ + # Ensure all foreign objects are persisted if database is provided + if database: + if model.policy: + database.set(model.policy, commit=False) + if model.dynamic: + database.set(model.dynamic, commit=False) + if model.dataset: + database.set(model.dataset, commit=False) + if model.model: + database.set(model.model, commit=False) + if model.model_version: + database.set(model.model_version, commit=False) + + sim_table = cls( + id=model.id, + created_at=model.created_at, + updated_at=model.updated_at, + policy_id=model.policy.id if model.policy else None, + dynamic_id=model.dynamic.id if model.dynamic else None, + dataset_id=model.dataset.id if model.dataset else None, + model_id=model.model.id if model.model else None, + model_version_id=model.model_version.id if model.model_version else None, + result=compress_data(model.result) if model.result else None, + ) + + # Handle nested aggregates if database is provided + if database and model.aggregates: + from .aggregate import AggregateTable + from sqlmodel import select + + # First ensure the simulation table is saved to the database + # This is necessary so the foreign key constraint is satisfied + # Check if it already exists + existing_sim = database.session.exec( + select(SimulationTable).where(SimulationTable.id == model.id) + ).first() + + if not existing_sim: + database.session.add(sim_table) + database.session.flush() + + # Track which aggregate IDs we want to keep + desired_agg_ids = {agg.id for agg in model.aggregates} + + # Delete only aggregates linked to this simulation that are NOT in the new list + existing_aggs = database.session.exec( + select(AggregateTable).where(AggregateTable.simulation_id == model.id) + ).all() + for agg in existing_aggs: + if agg.id not in desired_agg_ids: + database.session.delete(agg) + + # Now save/update the aggregates + for aggregate in model.aggregates: + # Check if this aggregate already exists in the database + existing_agg = database.session.exec( + select(AggregateTable).where(AggregateTable.id == aggregate.id) + ).first() + + if existing_agg: + # Update existing aggregate + agg_table = AggregateTable.convert_from_model(aggregate, database) + existing_agg.simulation_id = agg_table.simulation_id + existing_agg.entity = agg_table.entity + existing_agg.variable_name = agg_table.variable_name + existing_agg.year = agg_table.year + existing_agg.filter_variable_name = agg_table.filter_variable_name + existing_agg.filter_variable_value = agg_table.filter_variable_value + existing_agg.filter_variable_leq = agg_table.filter_variable_leq + existing_agg.filter_variable_geq = agg_table.filter_variable_geq + existing_agg.aggregate_function = agg_table.aggregate_function + existing_agg.value = agg_table.value + else: + # Create new aggregate + agg_table = AggregateTable.convert_from_model(aggregate, database) + database.session.add(agg_table) + database.session.flush() + + return sim_table + + def convert_to_model(self, database: "Database" = None) -> Simulation: + """Convert this SimulationTable instance to a Simulation instance. 
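+
+        Aggregates are converted without a database handle to avoid circular
+        simulation/aggregate resolution; callers can attach the simulation
+        reference afterwards.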
+
+        Args:
+            database: The database instance for resolving foreign keys
+
+        Returns:
+            A Simulation instance
+        """
+        from sqlmodel import select
+
+        from .model_version_table import ModelVersionTable
+        from .policy_table import PolicyTable
+        from .dataset_table import DatasetTable
+        from .model_table import ModelTable
+        from .dynamic_table import DynamicTable
+
+        # Resolve all foreign keys
+        policy = None
+        dynamic = None
+        dataset = None
+        model = None
+        model_version = None
+
+        if database:
+            if self.policy_id:
+                policy_table = database.session.exec(
+                    select(PolicyTable).where(PolicyTable.id == self.policy_id)
+                ).first()
+                if policy_table:
+                    policy = policy_table.convert_to_model(database)
+
+            if self.dynamic_id:
+                try:
+                    dynamic_table = database.session.exec(
+                        select(DynamicTable).where(DynamicTable.id == self.dynamic_id)
+                    ).first()
+                    if dynamic_table:
+                        dynamic = dynamic_table.convert_to_model(database)
+                except Exception:
+                    # Dynamic table might not be defined yet
+                    dynamic = database.get(Dynamic, id=self.dynamic_id)
+
+            if self.dataset_id:
+                dataset_table = database.session.exec(
+                    select(DatasetTable).where(DatasetTable.id == self.dataset_id)
+                ).first()
+                if dataset_table:
+                    dataset = dataset_table.convert_to_model(database)
+
+            if self.model_id:
+                model_table = database.session.exec(
+                    select(ModelTable).where(ModelTable.id == self.model_id)
+                ).first()
+                if model_table:
+                    model = model_table.convert_to_model(database)
+
+            if self.model_version_id:
+                version_table = database.session.exec(
+                    select(ModelVersionTable).where(ModelVersionTable.id == self.model_version_id)
+                ).first()
+                if version_table:
+                    model_version = version_table.convert_to_model(database)
+
+        # Load aggregates
+        aggregates = []
+        if database:
+            from .aggregate import AggregateTable
+
+            agg_tables = database.session.exec(
+                select(AggregateTable).where(AggregateTable.simulation_id == self.id)
+            ).all()
+
+            for agg_table in agg_tables:
+                # Don't pass database to avoid circular reference issues
+                # The simulation reference will be set separately
+                agg_model = agg_table.convert_to_model(None)
+                aggregates.append(agg_model)
+
+        return Simulation(
+            id=self.id,
+            created_at=self.created_at,
+            updated_at=self.updated_at,
+            policy=policy,
+            dynamic=dynamic,
+            dataset=dataset,
+            model=model,
+            model_version=model_version,
+            result=decompress_data(self.result) if self.result else None,
+            aggregates=aggregates,
+        )
+
+
+simulation_table_link = TableLink(
+    model_cls=Simulation,
+    table_cls=SimulationTable,
+)
diff --git a/build/lib/policyengine/database/table_mixin.py b/build/lib/policyengine/database/table_mixin.py
new file mode 100644
index 00000000..a29cdeb6
--- /dev/null
+++ b/build/lib/policyengine/database/table_mixin.py
@@ -0,0 +1,80 @@
+from typing import TYPE_CHECKING, Any, ClassVar, TypeVar
+
+from pydantic import BaseModel
+from sqlmodel import SQLModel
+
+if TYPE_CHECKING:
+    from .database import Database
+
+T = TypeVar("T", bound=BaseModel)
+
+
+class TableConversionMixin:
+    """Mixin class for SQLModel tables to provide conversion methods between table instances and Pydantic models."""
+
+    _model_cls: ClassVar[type[BaseModel]] = None
+    _foreign_key_fields: ClassVar[dict[str, type[BaseModel]]] = {}
+
+    @classmethod
+    def convert_from_model(cls, model: BaseModel, database: "Database" = None) -> SQLModel:
+        """Convert a Pydantic model instance to a table instance, resolving foreign objects to IDs.
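+
+        Illustrative subclass wiring (``WidgetTable``, ``Widget`` and ``Owner``
+        are hypothetical names, not part of this package):
+
+            class WidgetTable(TableConversionMixin, SQLModel, table=True):
+                _model_cls = Widget
+                _foreign_key_fields = {"owner_id": Owner}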
+
+        Args:
+            model: The Pydantic model instance to convert
+            database: The database instance for resolving foreign objects (optional)
+
+        Returns:
+            An instance of the SQLModel table class
+        """
+        data = {}
+
+        for field_name in cls.__annotations__.keys():
+            # Check if this field is a foreign key that needs resolution
+            if field_name in cls._foreign_key_fields:
+                # Extract ID from the nested object
+                nested_obj = getattr(model, field_name.replace("_id", ""), None)
+                if nested_obj:
+                    # If we need to ensure the foreign object exists in DB
+                    if database:
+                        database.set(nested_obj, commit=False)
+                    data[field_name] = nested_obj.id if hasattr(nested_obj, "id") else None
+                else:
+                    data[field_name] = None
+            elif hasattr(model, field_name):
+                # Direct field mapping
+                data[field_name] = getattr(model, field_name)
+
+        return cls(**data)
+
+    def convert_to_model(self, database: "Database" = None) -> BaseModel:
+        """Convert this table instance to a Pydantic model, resolving foreign key IDs to objects.
+
+        Defined as an instance method so subclasses match the
+        ``table.convert_to_model(database)`` convention used by the concrete
+        table classes in this package.
+
+        Args:
+            database: The database instance for resolving foreign keys (required if foreign keys exist)
+
+        Returns:
+            An instance of the Pydantic model class
+        """
+        cls = type(self)
+        if cls._model_cls is None:
+            raise ValueError(f"Model class not set for {cls.__name__}")
+
+        data = {}
+
+        for field_name in cls._model_cls.__annotations__.keys():
+            # Check if we need to resolve a foreign key
+            fk_field = f"{field_name}_id"
+            if fk_field in cls._foreign_key_fields and database:
+                # Resolve the foreign key to an object
+                fk_id = getattr(self, fk_field, None)
+                if fk_id:
+                    foreign_model_cls = cls._foreign_key_fields[fk_field]
+                    data[field_name] = database.get(foreign_model_cls, id=fk_id)
+                else:
+                    data[field_name] = None
+            elif hasattr(self, field_name):
+                # Direct field mapping
+                data[field_name] = getattr(self, field_name)
+
+        return cls._model_cls(**data)
\ No newline at end of file
diff --git a/build/lib/policyengine/database/user_table.py b/build/lib/policyengine/database/user_table.py
new file mode 100644
index 00000000..d663ac8f
--- /dev/null
+++ b/build/lib/policyengine/database/user_table.py
@@ -0,0 +1,57 @@
+import uuid
+from datetime import datetime
+
+from sqlmodel import Field, SQLModel
+from typing import TYPE_CHECKING
+
+from policyengine.models.user import User
+
+from .link import TableLink
+
+if TYPE_CHECKING:
+    from .database import Database
+
+
+class UserTable(SQLModel, table=True, extend_existing=True):
+    __tablename__ = "users"
+
+    id: str = Field(
+        primary_key=True, default_factory=lambda: str(uuid.uuid4())
+    )
+    username: str = Field(nullable=False, unique=True)
+    first_name: str | None = Field(default=None)
+    last_name: str | None = Field(default=None)
+    email: str | None = Field(default=None)
+    created_at: datetime = Field(default_factory=datetime.utcnow)
+    updated_at: datetime = Field(default_factory=datetime.utcnow)
+
+    @classmethod
+    def convert_from_model(cls, model: User, database: "Database" = None) -> "UserTable":
+        """Convert a User instance to a UserTable instance."""
+        return cls(
+            id=model.id,
+            username=model.username,
+            first_name=model.first_name,
+            last_name=model.last_name,
+            email=model.email,
+            created_at=model.created_at,
+            updated_at=model.updated_at,
+        )
+
+    def convert_to_model(self, database: "Database" = None) -> User:
+        """Convert this UserTable instance to a User instance."""
+        return User(
+            id=self.id,
+            username=self.username,
+            first_name=self.first_name,
+            last_name=self.last_name,
+            email=self.email,
+            created_at=self.created_at,
+            updated_at=self.updated_at,
+        )
+
+
+user_table_link = TableLink(
+    model_cls=User,
+    table_cls=UserTable,
+)
diff --git a/build/lib/policyengine/database/versioned_dataset_table.py b/build/lib/policyengine/database/versioned_dataset_table.py
new file mode 100644
index 00000000..4e1524c9
--- /dev/null
+++ b/build/lib/policyengine/database/versioned_dataset_table.py
@@ -0,0 +1,45 @@
+from uuid import uuid4
+
+from sqlmodel import Field, SQLModel
+from typing import TYPE_CHECKING
+
+from policyengine.models import Model, VersionedDataset
+
+from .link import TableLink
+
+if TYPE_CHECKING:
+    from .database import Database
+
+
+class VersionedDatasetTable(SQLModel, table=True):
+    __tablename__ = "versioned_datasets"
+
+    id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True)
+    name: str = Field(nullable=False)
+    description: str = Field(nullable=False)
+    model_id: str | None = Field(
+        default=None, foreign_key="models.id", ondelete="SET NULL"
+    )
+
+    @classmethod
+    def convert_from_model(cls, model: VersionedDataset, database: "Database" = None) -> "VersionedDatasetTable":
+        """Convert a VersionedDataset instance to a VersionedDatasetTable instance."""
+        return cls(
+            id=model.id,
+            name=model.name,
+            description=model.description,
+            model_id=model.model.id if model.model else None,
+        )
+
+    def convert_to_model(self, database: "Database" = None) -> VersionedDataset:
+        """Convert this VersionedDatasetTable instance to a VersionedDataset instance."""
+        # Resolve the model foreign key so round trips preserve the model link
+        model = None
+        if database and self.model_id:
+            model = database.get(Model, id=self.model_id)
+
+        return VersionedDataset(
+            id=self.id,
+            name=self.name,
+            description=self.description,
+            model=model,
+        )
+
+
+versioned_dataset_table_link = TableLink(
+    model_cls=VersionedDataset,
+    table_cls=VersionedDatasetTable,
+)
diff --git a/build/lib/policyengine/models/__init__.py b/build/lib/policyengine/models/__init__.py
new file mode 100644
index 00000000..de5fd8c9
--- /dev/null
+++ b/build/lib/policyengine/models/__init__.py
@@ -0,0 +1,39 @@
+from .aggregate import Aggregate as Aggregate
+from .aggregate import AggregateType as AggregateType
+from .aggregate_change import AggregateChange as AggregateChange
+from .baseline_parameter_value import (
+    BaselineParameterValue as BaselineParameterValue,
+)
+from .baseline_variable import BaselineVariable as BaselineVariable
+from .dataset import Dataset as Dataset
+from .dynamic import Dynamic as Dynamic
+from .model import Model as Model
+from .model_version import ModelVersion as ModelVersion
+from .parameter import Parameter as Parameter
+from .parameter_value import ParameterValue as ParameterValue
+from .policy import Policy as Policy
+from .policyengine_uk import (
+    policyengine_uk_latest_version as policyengine_uk_latest_version,
+)
+from .policyengine_uk import (
+    policyengine_uk_model as policyengine_uk_model,
+)
+from .policyengine_us import (
+    policyengine_us_latest_version as policyengine_us_latest_version,
+)
+from .policyengine_us import (
+    policyengine_us_model as policyengine_us_model,
+)
+from .report import Report as Report
+from .report_element import ReportElement as ReportElement
+from .simulation import Simulation as Simulation
+from .user import User as User
+from .versioned_dataset import VersionedDataset as VersionedDataset
+
+# Rebuild models to handle circular references
+from .aggregate import Aggregate
+from .aggregate_change import AggregateChange
+from .simulation import Simulation
+Aggregate.model_rebuild()
+AggregateChange.model_rebuild()
+Simulation.model_rebuild()
diff --git
a/build/lib/policyengine/models/aggregate.py b/build/lib/policyengine/models/aggregate.py new file mode 100644 index 00000000..031cad87 --- /dev/null +++ b/build/lib/policyengine/models/aggregate.py @@ -0,0 +1,132 @@ +from enum import Enum +from typing import TYPE_CHECKING, Literal +from uuid import uuid4 + +import pandas as pd +from microdf import MicroDataFrame +from pydantic import BaseModel, Field + +if TYPE_CHECKING: + from policyengine.models import Simulation + + +class AggregateType(str, Enum): + SUM = "sum" + MEAN = "mean" + MEDIAN = "median" + COUNT = "count" + + +class Aggregate(BaseModel): + id: str = Field(default_factory=lambda: str(uuid4())) + simulation: "Simulation | None" = None + entity: str + variable_name: str + year: int | None = None + filter_variable_name: str | None = None + filter_variable_value: str | None = None + filter_variable_leq: float | None = None + filter_variable_geq: float | None = None + aggregate_function: Literal[ + AggregateType.SUM, AggregateType.MEAN, AggregateType.MEDIAN, AggregateType.COUNT + ] + reportelement_id: str | None = None + + value: float | None = None + + @staticmethod + def run(aggregates: list["Aggregate"]) -> list["Aggregate"]: + """Process aggregates, handling multiple simulations if necessary.""" + # Group aggregates by simulation + simulation_groups = {} + for agg in aggregates: + sim_id = id(agg.simulation) if agg.simulation else None + if sim_id not in simulation_groups: + simulation_groups[sim_id] = [] + simulation_groups[sim_id].append(agg) + + # Process each simulation group separately + all_results = [] + for sim_id, sim_aggregates in simulation_groups.items(): + if not sim_aggregates: + continue + + # Get the simulation from the first aggregate in this group + simulation = sim_aggregates[0].simulation + if simulation is None: + raise ValueError("Aggregate has no simulation attached") + + # Process this simulation's aggregates + group_results = Aggregate._process_simulation_aggregates( + sim_aggregates, simulation + ) + all_results.extend(group_results) + + return all_results + + @staticmethod + def _process_simulation_aggregates( + aggregates: list["Aggregate"], simulation: "Simulation" + ) -> list["Aggregate"]: + """Process aggregates for a single simulation.""" + results = [] + + tables = simulation.result + # Copy tables to ensure we don't modify original dataframes + tables = {k: v.copy() for k, v in tables.items()} + for table in tables: + tables[table] = pd.DataFrame(tables[table]) + weight_col = f"{table}_weight" + if weight_col in tables[table].columns: + tables[table] = MicroDataFrame( + tables[table], weights=weight_col + ) + + for agg in aggregates: + if agg.entity not in tables: + raise ValueError( + f"Entity {agg.entity} not found in simulation results" + ) + table = tables[agg.entity] + + if agg.variable_name not in table.columns: + raise ValueError( + f"Variable {agg.variable_name} not found in entity {agg.entity}" + ) + + df = table + + if agg.year is None: + agg.year = simulation.dataset.year + + if agg.filter_variable_name is not None: + if agg.filter_variable_name not in df.columns: + raise ValueError( + f"Filter variable {agg.filter_variable_name} not found in entity {agg.entity}" + ) + if agg.filter_variable_value is not None: + df = df[ + df[agg.filter_variable_name] + == agg.filter_variable_value + ] + if agg.filter_variable_leq is not None: + df = df[ + df[agg.filter_variable_name] <= agg.filter_variable_leq + ] + if agg.filter_variable_geq is not None: + df = df[ + df[agg.filter_variable_name] >= 
agg.filter_variable_geq
+                ]
+
+            if agg.aggregate_function == AggregateType.SUM:
+                agg.value = float(df[agg.variable_name].sum())
+            elif agg.aggregate_function == AggregateType.MEAN:
+                agg.value = float(df[agg.variable_name].mean())
+            elif agg.aggregate_function == AggregateType.MEDIAN:
+                agg.value = float(df[agg.variable_name].median())
+            elif agg.aggregate_function == AggregateType.COUNT:
+                agg.value = float((df[agg.variable_name] > 0).sum())
+
+            results.append(agg)
+
+        return results
diff --git a/build/lib/policyengine/models/aggregate_change.py b/build/lib/policyengine/models/aggregate_change.py
new file mode 100644
index 00000000..e0a400df
--- /dev/null
+++ b/build/lib/policyengine/models/aggregate_change.py
@@ -0,0 +1,143 @@
+from typing import TYPE_CHECKING, Literal
+from uuid import uuid4
+
+import pandas as pd
+from microdf import MicroDataFrame
+from pydantic import BaseModel, Field
+
+# Reuse the enum from aggregate.py rather than redefining it here.
+from .aggregate import AggregateType
+
+if TYPE_CHECKING:
+    from policyengine.models import Simulation
+
+
+class AggregateChange(BaseModel):
+    id: str = Field(default_factory=lambda: str(uuid4()))
+    baseline_simulation: "Simulation | None" = None
+    comparison_simulation: "Simulation | None" = None
+    entity: str
+    variable_name: str
+    year: int | None = None
+    filter_variable_name: str | None = None
+    filter_variable_value: str | None = None
+    filter_variable_leq: float | None = None
+    filter_variable_geq: float | None = None
+    aggregate_function: Literal[
+        AggregateType.SUM, AggregateType.MEAN, AggregateType.MEDIAN, AggregateType.COUNT
+    ]
+    reportelement_id: str | None = None
+
+    baseline_value: float | None = None
+    comparison_value: float | None = None
+    change: float | None = None
+    relative_change: float | None = None
+
+    @staticmethod
+    def run(aggregate_changes: list["AggregateChange"]) -> list["AggregateChange"]:
+        """Process aggregate changes, handling multiple simulation pairs."""
+        results = []
+
+        for agg_change in aggregate_changes:
+            if agg_change.baseline_simulation is None:
+                raise ValueError("AggregateChange has no baseline simulation attached")
+            if agg_change.comparison_simulation is None:
+                raise ValueError("AggregateChange has no comparison simulation attached")
+
+            # Compute baseline value
+            baseline_value = AggregateChange._compute_single_aggregate(
+                agg_change, agg_change.baseline_simulation
+            )
+
+            # Compute comparison value
+            comparison_value = AggregateChange._compute_single_aggregate(
+                agg_change, agg_change.comparison_simulation
+            )
+
+            # Compute changes
+            agg_change.baseline_value = baseline_value
+            agg_change.comparison_value = comparison_value
+            agg_change.change = comparison_value - baseline_value
+
+            # Compute relative change (avoiding division by zero)
+            if baseline_value != 0:
+                agg_change.relative_change = (comparison_value - baseline_value) / abs(baseline_value)
+            else:
+                agg_change.relative_change = None if comparison_value == 0 else float('inf')
+
+            results.append(agg_change)
+
+        return results
+
+    @staticmethod
+    def _compute_single_aggregate(
+        agg_change: "AggregateChange", simulation: "Simulation"
+    ) -> float:
+        """Compute aggregate value for a single simulation."""
+        tables = simulation.result
+        # Copy tables to ensure we don't modify original dataframes
+        tables = {k: v.copy() for k, v in tables.items()}
+
+        for table in tables:
+            tables[table] = pd.DataFrame(tables[table])
+            weight_col = f"{table}_weight"
+            if weight_col in tables[table].columns:
+                tables[table]
= MicroDataFrame( + tables[table], weights=weight_col + ) + + if agg_change.entity not in tables: + raise ValueError( + f"Entity {agg_change.entity} not found in simulation results" + ) + + table = tables[agg_change.entity] + + if agg_change.variable_name not in table.columns: + raise ValueError( + f"Variable {agg_change.variable_name} not found in entity {agg_change.entity}" + ) + + df = table + + if agg_change.year is None: + agg_change.year = simulation.dataset.year + + # Apply filters + if agg_change.filter_variable_name is not None: + if agg_change.filter_variable_name not in df.columns: + raise ValueError( + f"Filter variable {agg_change.filter_variable_name} not found in entity {agg_change.entity}" + ) + if agg_change.filter_variable_value is not None: + df = df[ + df[agg_change.filter_variable_name] + == agg_change.filter_variable_value + ] + if agg_change.filter_variable_leq is not None: + df = df[ + df[agg_change.filter_variable_name] <= agg_change.filter_variable_leq + ] + if agg_change.filter_variable_geq is not None: + df = df[ + df[agg_change.filter_variable_name] >= agg_change.filter_variable_geq + ] + + # Compute aggregate + if agg_change.aggregate_function == AggregateType.SUM: + value = float(df[agg_change.variable_name].sum()) + elif agg_change.aggregate_function == AggregateType.MEAN: + value = float(df[agg_change.variable_name].mean()) + elif agg_change.aggregate_function == AggregateType.MEDIAN: + value = float(df[agg_change.variable_name].median()) + elif agg_change.aggregate_function == AggregateType.COUNT: + value = float((df[agg_change.variable_name] > 0).sum()) + else: + raise ValueError(f"Unknown aggregate function: {agg_change.aggregate_function}") + + return value \ No newline at end of file diff --git a/build/lib/policyengine/models/baseline_parameter_value.py b/build/lib/policyengine/models/baseline_parameter_value.py new file mode 100644 index 00000000..8afb6e22 --- /dev/null +++ b/build/lib/policyengine/models/baseline_parameter_value.py @@ -0,0 +1,16 @@ +from datetime import datetime +from uuid import uuid4 + +from pydantic import BaseModel, Field + +from .model_version import ModelVersion +from .parameter import Parameter + + +class BaselineParameterValue(BaseModel): + id: str = Field(default_factory=lambda: str(uuid4())) + parameter: Parameter + model_version: ModelVersion + value: float | int | str | bool | list | None = None + start_date: datetime + end_date: datetime | None = None diff --git a/build/lib/policyengine/models/baseline_variable.py b/build/lib/policyengine/models/baseline_variable.py new file mode 100644 index 00000000..b0e739b1 --- /dev/null +++ b/build/lib/policyengine/models/baseline_variable.py @@ -0,0 +1,12 @@ +from pydantic import BaseModel + +from .model_version import ModelVersion + + +class BaselineVariable(BaseModel): + id: str + model_version: ModelVersion + entity: str + label: str | None = None + description: str | None = None + data_type: type | None = None diff --git a/build/lib/policyengine/models/dataset.py b/build/lib/policyengine/models/dataset.py new file mode 100644 index 00000000..59dd626f --- /dev/null +++ b/build/lib/policyengine/models/dataset.py @@ -0,0 +1,18 @@ +from typing import Any +from uuid import uuid4 + +from pydantic import BaseModel, Field + +from .model import Model +from .versioned_dataset import VersionedDataset + + +class Dataset(BaseModel): + id: str = Field(default_factory=lambda: str(uuid4())) + name: str + description: str | None = None + version: str | None = None + versioned_dataset: 
VersionedDataset | None = None
+    year: int | None = None
+    data: Any | None = None
+    model: Model | None = None
diff --git a/build/lib/policyengine/models/dynamic.py b/build/lib/policyengine/models/dynamic.py
new file mode 100644
index 00000000..40cf364f
--- /dev/null
+++ b/build/lib/policyengine/models/dynamic.py
@@ -0,0 +1,15 @@
+from collections.abc import Callable
+from datetime import datetime
+from uuid import uuid4
+
+from pydantic import BaseModel, Field
+
+from .parameter_value import ParameterValue
+
+
+class Dynamic(BaseModel):
+    id: str = Field(default_factory=lambda: str(uuid4()))
+    name: str
+    description: str | None = None
+    # Typed as ParameterValue: simulation runners read .parameter.id, .value
+    # and .start_date from these entries.
+    parameter_values: list[ParameterValue] = []
+    simulation_modifier: Callable | None = None
+    created_at: datetime = Field(default_factory=datetime.now)
+    updated_at: datetime = Field(default_factory=datetime.now)
diff --git a/build/lib/policyengine/models/model.py b/build/lib/policyengine/models/model.py
new file mode 100644
index 00000000..89cac9b8
--- /dev/null
+++ b/build/lib/policyengine/models/model.py
@@ -0,0 +1,126 @@
+from collections.abc import Callable
+from datetime import datetime
+from typing import TYPE_CHECKING
+
+from pydantic import BaseModel
+
+if TYPE_CHECKING:
+    from .baseline_parameter_value import BaselineParameterValue
+    from .baseline_variable import BaselineVariable
+    from .parameter import Parameter
+
+
+class Model(BaseModel):
+    id: str
+    name: str
+    description: str | None = None
+    simulation_function: Callable
+
+    def create_seed_objects(self, model_version):
+        from policyengine_core.parameters import Parameter as CoreParameter
+
+        from .baseline_parameter_value import BaselineParameterValue
+        from .baseline_variable import BaselineVariable
+        from .parameter import Parameter
+
+        if self.id == "policyengine_uk":
+            from policyengine_uk.tax_benefit_system import system
+        elif self.id == "policyengine_us":
+            from policyengine_us.system import system
+        else:
+            raise ValueError("Unsupported model.")
+
+        parameters = []
+        baseline_parameter_values = []
+        baseline_variables = []
+        seen_parameter_ids = set()
+
+        for parameter in system.parameters.get_descendants():
+            # Skip if we've already processed this parameter ID
+            if parameter.name in seen_parameter_ids:
+                continue
+            seen_parameter_ids.add(parameter.name)
+            param = Parameter(
+                id=parameter.name,
+                description=parameter.description,
+                data_type=None,
+                model=self,
+                label=parameter.metadata.get("label"),
+                unit=parameter.metadata.get("unit"),
+            )
+            parameters.append(param)
+            if isinstance(parameter, CoreParameter):
+                values = parameter.values_list[::-1]
+                param.data_type = type(values[-1].value)
+                for i in range(len(values)):
+                    value_at_instant = values[i]
+                    instant_str = safe_parse_instant_str(
+                        value_at_instant.instant_str
+                    )
+                    if i + 1 < len(values):
+                        next_instant_str = safe_parse_instant_str(
+                            values[i + 1].instant_str
+                        )
+                    else:
+                        next_instant_str = None
+                    baseline_param_value = BaselineParameterValue(
+                        parameter=param,
+                        model_version=model_version,
+                        value=value_at_instant.value,
+                        start_date=instant_str,
+                        end_date=next_instant_str,
+                    )
+                    baseline_parameter_values.append(baseline_param_value)
+
+        for variable in system.variables.values():
+            baseline_variable = BaselineVariable(
+                id=variable.name,
+                model_version=model_version,
+                entity=variable.entity.key,
+                label=variable.label,
+                description=variable.documentation,
+                data_type=variable.value_type,
+            )
+            baseline_variables.append(baseline_variable)
+
+        return SeedObjects(
+            parameters=parameters,
+            baseline_parameter_values=baseline_parameter_values,
+
baseline_variables=baseline_variables, + ) + + +def safe_parse_instant_str(instant_str: str) -> datetime: + if instant_str == "0000-01-01": + return datetime(1, 1, 1) + else: + try: + return datetime.strptime(instant_str, "%Y-%m-%d") + except ValueError: + # Handle invalid dates like 2021-06-31 + # Try to parse year and month, then use last valid day + parts = instant_str.split("-") + if len(parts) == 3: + year = int(parts[0]) + month = int(parts[1]) + day = int(parts[2]) + + # Find the last valid day of the month + import calendar + + last_day = calendar.monthrange(year, month)[1] + if day > last_day: + print( + f"Warning: Invalid date {instant_str}, using {year}-{month:02d}-{last_day:02d}" + ) + return datetime(year, month, last_day) + + # If we can't parse it at all, print and raise + print(f"Error: Cannot parse date {instant_str}") + raise + + +class SeedObjects(BaseModel): + parameters: list["Parameter"] + baseline_parameter_values: list["BaselineParameterValue"] + baseline_variables: list["BaselineVariable"] diff --git a/build/lib/policyengine/models/model_version.py b/build/lib/policyengine/models/model_version.py new file mode 100644 index 00000000..18b542f8 --- /dev/null +++ b/build/lib/policyengine/models/model_version.py @@ -0,0 +1,14 @@ +from datetime import datetime +from uuid import uuid4 + +from pydantic import BaseModel, Field + +from .model import Model + + +class ModelVersion(BaseModel): + id: str = Field(default_factory=lambda: str(uuid4())) + model: Model + version: str + description: str | None = None + created_at: datetime = Field(default_factory=datetime.now) diff --git a/build/lib/policyengine/models/parameter.py b/build/lib/policyengine/models/parameter.py new file mode 100644 index 00000000..ec7ef7be --- /dev/null +++ b/build/lib/policyengine/models/parameter.py @@ -0,0 +1,14 @@ +from uuid import uuid4 + +from pydantic import BaseModel, Field + +from .model import Model + + +class Parameter(BaseModel): + id: str = Field(default_factory=lambda: str(uuid4())) + description: str | None = None + data_type: type | None = None + model: Model | None = None + label: str | None = None + unit: str | None = None diff --git a/build/lib/policyengine/models/parameter_value.py b/build/lib/policyengine/models/parameter_value.py new file mode 100644 index 00000000..a7867557 --- /dev/null +++ b/build/lib/policyengine/models/parameter_value.py @@ -0,0 +1,14 @@ +from datetime import datetime +from uuid import uuid4 + +from pydantic import BaseModel, Field + +from .parameter import Parameter + + +class ParameterValue(BaseModel): + id: str = Field(default_factory=lambda: str(uuid4())) + parameter: Parameter + value: float | int | str | bool | list | None = None + start_date: datetime + end_date: datetime | None = None diff --git a/build/lib/policyengine/models/policy.py b/build/lib/policyengine/models/policy.py new file mode 100644 index 00000000..20587d85 --- /dev/null +++ b/build/lib/policyengine/models/policy.py @@ -0,0 +1,17 @@ +from collections.abc import Callable +from datetime import datetime +from uuid import uuid4 + +from pydantic import BaseModel, Field + +from .parameter_value import ParameterValue + + +class Policy(BaseModel): + id: str = Field(default_factory=lambda: str(uuid4())) + name: str + description: str | None = None + parameter_values: list[ParameterValue] = [] + simulation_modifier: Callable | None = None + created_at: datetime = Field(default_factory=datetime.now) + updated_at: datetime = Field(default_factory=datetime.now) diff --git 
a/build/lib/policyengine/models/policyengine_uk.py b/build/lib/policyengine/models/policyengine_uk.py
new file mode 100644
index 00000000..5b97ccfb
--- /dev/null
+++ b/build/lib/policyengine/models/policyengine_uk.py
@@ -0,0 +1,113 @@
+import importlib.metadata
+
+import pandas as pd
+
+from ..models import Dataset, Dynamic, Model, ModelVersion, Policy
+
+
+def run_policyengine_uk(
+    dataset: "Dataset",
+    policy: "Policy | None" = None,
+    dynamic: "Dynamic | None" = None,
+) -> dict[str, "pd.DataFrame"]:
+    data: dict[str, pd.DataFrame] = dataset.data
+
+    from policyengine_uk import Microsimulation
+    from policyengine_uk.data import UKSingleYearDataset
+
+    pe_input_data = UKSingleYearDataset(
+        person=data["person"],
+        benunit=data["benunit"],
+        household=data["household"],
+        fiscal_year=dataset.year,
+    )
+
+    sim = Microsimulation(dataset=pe_input_data)
+    sim.default_calculation_period = dataset.year
+
+    def simulation_modifier(sim: Microsimulation):
+        if policy is not None and len(policy.parameter_values) > 0:
+            for parameter_value in policy.parameter_values:
+                sim.tax_benefit_system.parameters.get_child(
+                    parameter_value.parameter.id
+                ).update(
+                    value=parameter_value.value,
+                    start=parameter_value.start_date.strftime("%Y-%m-%d"),
+                    stop=parameter_value.end_date.strftime("%Y-%m-%d")
+                    if parameter_value.end_date
+                    else None,
+                )
+
+        if dynamic is not None and len(dynamic.parameter_values) > 0:
+            for parameter_value in dynamic.parameter_values:
+                sim.tax_benefit_system.parameters.get_child(
+                    parameter_value.parameter.id
+                ).update(
+                    value=parameter_value.value,
+                    start=parameter_value.start_date.strftime("%Y-%m-%d"),
+                    stop=parameter_value.end_date.strftime("%Y-%m-%d")
+                    if parameter_value.end_date
+                    else None,
+                )
+
+        if dynamic is not None and dynamic.simulation_modifier is not None:
+            dynamic.simulation_modifier(sim)
+        if policy is not None and policy.simulation_modifier is not None:
+            policy.simulation_modifier(sim)
+
+    simulation_modifier(sim)
+
+    output_data = {}
+
+    # TEMPORARY: exclude these variables until policyengine-uk is fixed so
+    # they are only expensive to compute when non-default parameters are
+    # enabled.
+    variable_blacklist = [
+ "is_uc_entitled_baseline", + "income_elasticity_lsr", + "child_benefit_opts_out", + "housing_benefit_baseline_entitlement", + "baseline_ctc_entitlement", + "pre_budget_change_household_tax", + "pre_budget_change_household_net_income", + "is_on_cliff", + "marginal_tax_rate_on_capital_gains", + "relative_capital_gains_mtr_change", + "pre_budget_change_ons_equivalised_income_decile", + "substitution_elasticity", + "marginal_tax_rate", + "cliff_evaluated", + "cliff_gap", + "substitution_elasticity_lsr", + "relative_wage_change", + "relative_income_change", + "pre_budget_change_household_benefits", + ] + + for entity in ["person", "benunit", "household"]: + output_data[entity] = pd.DataFrame() + for variable in sim.tax_benefit_system.variables.values(): + correct_entity = variable.entity.key == entity + if variable.name in variable_blacklist: + continue + if variable.definition_period != "year": + continue + if correct_entity: + output_data[entity][variable.name] = sim.calculate( + variable.name + ).values + output_data[entity] = pd.DataFrame(output_data[entity]) + + return output_data + + +policyengine_uk_model = Model( + id="policyengine_uk", + name="PolicyEngine UK", + description="PolicyEngine's open-source tax-benefit microsimulation model.", + simulation_function=run_policyengine_uk, +) + +# Get policyengine-uk version + +policyengine_uk_latest_version = ModelVersion( + model=policyengine_uk_model, + version=importlib.metadata.distribution("policyengine_uk").version, +) diff --git a/build/lib/policyengine/models/policyengine_us.py b/build/lib/policyengine/models/policyengine_us.py new file mode 100644 index 00000000..9e2eeb7d --- /dev/null +++ b/build/lib/policyengine/models/policyengine_us.py @@ -0,0 +1,115 @@ +import importlib.metadata + +import pandas as pd + +from ..models import Dataset, Dynamic, Model, ModelVersion, Policy + + +def run_policyengine_us( + dataset: "Dataset", + policy: "Policy | None" = None, + dynamic: "Dynamic | None" = None, +) -> dict[str, "pd.DataFrame"]: + data: dict[str, pd.DataFrame] = dataset.data + + person_df = pd.DataFrame() + + for table_name, table in data.items(): + if table_name == "person": + for col in table.columns: + person_df[f"{col}__{dataset.year}"] = table[col].values + else: + foreign_key = data["person"][f"person_{table_name}_id"] + primary_key = data[table_name][f"{table_name}_id"] + + projected = table.set_index(primary_key).loc[foreign_key] + + for col in projected.columns: + person_df[f"{col}__{dataset.year}"] = projected[col].values + + from policyengine_us import Microsimulation + + sim = Microsimulation(dataset=person_df) + sim.default_calculation_period = dataset.year + + def simulation_modifier(sim: Microsimulation): + if policy is not None and len(policy.parameter_values) > 0: + for parameter_value in policy.parameter_values: + sim.tax_benefit_system.parameters.get_child( + parameter_value.parameter.id + ).update( + parameter_value.value, + start=parameter_value.start_date.strftime("%Y-%m-%d"), + stop=parameter_value.end_date.strftime("%Y-%m-%d") + if parameter_value.end_date + else None, + ) + + if dynamic is not None and len(dynamic.parameter_values) > 0: + for parameter_value in dynamic.parameter_values: + sim.tax_benefit_system.parameters.get_child( + parameter_value.parameter.id + ).update( + parameter_value.value, + start=parameter_value.start_date.strftime("%Y-%m-%d"), + stop=parameter_value.end_date.strftime("%Y-%m-%d") + if parameter_value.end_date + else None, + ) + + if dynamic is not None and 
dynamic.simulation_modifier is not None:
+            dynamic.simulation_modifier(sim)
+        if policy is not None and policy.simulation_modifier is not None:
+            policy.simulation_modifier(sim)
+
+    simulation_modifier(sim)
+
+    # Skip reforms for now
+
+    output_data = {}
+
+    variable_whitelist = [
+        "household_net_income",
+    ]
+
+    for variable in variable_whitelist:
+        sim.calculate(variable)
+
+    for entity in [
+        "person",
+        "marital_unit",
+        "family",
+        "tax_unit",
+        "spm_unit",
+        "household",
+    ]:
+        output_data[entity] = pd.DataFrame()
+        for variable in sim.tax_benefit_system.variables.values():
+            correct_entity = variable.entity.key == entity
+            if str(dataset.year) not in list(
+                map(str, sim.get_holder(variable.name).get_known_periods())
+            ):
+                continue
+            if variable.definition_period != "year":
+                continue
+            if not correct_entity:
+                continue
+            output_data[entity][variable.name] = sim.calculate(variable.name).values
+
+    return output_data
+
+
+policyengine_us_model = Model(
+    id="policyengine_us",
+    name="PolicyEngine US",
+    description="PolicyEngine's open-source tax-benefit microsimulation model.",
+    simulation_function=run_policyengine_us,
+)
+
+# Get policyengine-us version
+
+
+policyengine_us_latest_version = ModelVersion(
+    model=policyengine_us_model,
+    version=importlib.metadata.distribution("policyengine_us").version,
+)
diff --git a/build/lib/policyengine/models/report.py b/build/lib/policyengine/models/report.py
new file mode 100644
index 00000000..2ae0cd3b
--- /dev/null
+++ b/build/lib/policyengine/models/report.py
@@ -0,0 +1,20 @@
+import uuid
+from datetime import datetime
+from typing import TYPE_CHECKING, ForwardRef
+
+from pydantic import BaseModel, Field
+
+if TYPE_CHECKING:
+    from .report_element import ReportElement
+
+
+class Report(BaseModel):
+    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
+    label: str
+    created_at: datetime | None = None
+    elements: list[ForwardRef("ReportElement")] = Field(default_factory=list)
+
+
+# Import after class definition to avoid circular import
+from .report_element import ReportElement
+Report.model_rebuild()
diff --git a/build/lib/policyengine/models/report_element.py b/build/lib/policyengine/models/report_element.py
new file mode 100644
index 00000000..180ec26c
--- /dev/null
+++ b/build/lib/policyengine/models/report_element.py
@@ -0,0 +1,38 @@
+import uuid
+from datetime import datetime
+from typing import Literal
+
+from pydantic import BaseModel, Field
+
+
+class ReportElement(BaseModel):
+    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
+    label: str
+    type: Literal["chart", "markdown"]
+
+    # Data source
+    data_table: Literal["aggregates", "aggregate_changes"] | None = None  # Which table to pull from
+
+    # Chart configuration
+    chart_type: (
+        Literal["bar", "line", "scatter", "area", "pie", "histogram"] | None
+    ) = None
+    x_axis_variable: str | None = None  # Column name from the table
+    y_axis_variable: str | None = None  # Column name from the table
+    group_by: str | None = None  # Column to group/split series by
+    color_by: str | None = None  # Column for color mapping
+    size_by: str | None = None  # Column for size mapping (bubble charts)
+
+    # Markdown specific
+    markdown_content: str | None = None
+
+    # Metadata
+    report_id: str | None = None
+    user_id: str | None = None
+    model_version_id: str | None = None
+    position: int | None = None
+    visible: bool | None = True
+    custom_config: dict | None = None  # Additional chart-specific config
+    report_element_metadata: dict | None = None  # General metadata field for
flexible data storage + created_at: datetime | None = None + updated_at: datetime | None = None diff --git a/build/lib/policyengine/models/simulation.py b/build/lib/policyengine/models/simulation.py new file mode 100644 index 00000000..a6ed7a5a --- /dev/null +++ b/build/lib/policyengine/models/simulation.py @@ -0,0 +1,35 @@ +from datetime import datetime +from typing import Any +from uuid import uuid4 + +from pydantic import BaseModel, Field + +from .dataset import Dataset +from .dynamic import Dynamic +from .model import Model +from .model_version import ModelVersion +from .policy import Policy + + +class Simulation(BaseModel): + id: str = Field(default_factory=lambda: str(uuid4())) + created_at: datetime = Field(default_factory=datetime.now) + updated_at: datetime = Field(default_factory=datetime.now) + + policy: Policy | None = None + dynamic: Dynamic | None = None + dataset: Dataset + + model: Model + model_version: ModelVersion + result: Any | None = None + aggregates: list = Field(default_factory=list) # Will be list[Aggregate] but avoid circular import + + def run(self): + self.result = self.model.simulation_function( + dataset=self.dataset, + policy=self.policy, + dynamic=self.dynamic, + ) + self.updated_at = datetime.now() + return self.result diff --git a/build/lib/policyengine/models/user.py b/build/lib/policyengine/models/user.py new file mode 100644 index 00000000..dee924e1 --- /dev/null +++ b/build/lib/policyengine/models/user.py @@ -0,0 +1,14 @@ +import uuid +from datetime import datetime + +from pydantic import BaseModel, Field + + +class User(BaseModel): + id: str = Field(default_factory=lambda: str(uuid.uuid4())) + username: str + first_name: str | None = None + last_name: str | None = None + email: str | None = None + created_at: datetime | None = None + updated_at: datetime | None = None diff --git a/build/lib/policyengine/models/versioned_dataset.py b/build/lib/policyengine/models/versioned_dataset.py new file mode 100644 index 00000000..2f5e14f7 --- /dev/null +++ b/build/lib/policyengine/models/versioned_dataset.py @@ -0,0 +1,12 @@ +from uuid import uuid4 + +from pydantic import BaseModel, Field + +from .model import Model + + +class VersionedDataset(BaseModel): + id: str = Field(default_factory=lambda: str(uuid4())) + name: str + description: str + model: Model | None = None diff --git a/build/lib/policyengine/utils/charts.py b/build/lib/policyengine/utils/charts.py new file mode 100644 index 00000000..0cee7048 --- /dev/null +++ b/build/lib/policyengine/utils/charts.py @@ -0,0 +1,286 @@ +"""Chart formatting utilities for PolicyEngine.""" + +import plotly.graph_objects as go +from IPython.display import HTML + +COLOUR_SCHEMES = { + "teal": { + "primary": "#319795", + "secondary": "#38B2AC", + "tertiary": "#4FD1C5", + "light": "#81E6D9", + "lighter": "#B2F5EA", + "lightest": "#E6FFFA", + "dark": "#2C7A7B", + "darker": "#285E61", + "darkest": "#234E52", + }, + "blue": { + "primary": "#0EA5E9", + "secondary": "#0284C7", + "tertiary": "#38BDF8", + "light": "#7DD3FC", + "lighter": "#BAE6FD", + "lightest": "#E0F2FE", + "dark": "#026AA2", + "darker": "#075985", + "darkest": "#0C4A6E", + }, + "gray": { + "primary": "#6B7280", + "secondary": "#9CA3AF", + "tertiary": "#D1D5DB", + "light": "#E2E8F0", + "lighter": "#F2F4F7", + "lightest": "#F9FAFB", + "dark": "#4B5563", + "darker": "#344054", + "darkest": "#101828", + }, +} + +DEFAULT_COLOURS = [ + COLOUR_SCHEMES["teal"]["primary"], + COLOUR_SCHEMES["blue"]["primary"], + COLOUR_SCHEMES["teal"]["secondary"], + 
COLOUR_SCHEMES["blue"]["secondary"], + COLOUR_SCHEMES["teal"]["tertiary"], + COLOUR_SCHEMES["blue"]["tertiary"], + COLOUR_SCHEMES["gray"]["dark"], + COLOUR_SCHEMES["teal"]["dark"], +] + + +def add_fonts() -> HTML: + """Return HTML to add Google Fonts for Roboto and Roboto Mono.""" + return HTML(""" + + + + """) + + +def format_figure( + fig: go.Figure, + title: str | None = None, + x_title: str | None = None, + y_title: str | None = None, + colour_scheme: str = "teal", + show_grid: bool = True, + show_legend: bool = True, + height: int | None = None, + width: int | None = None, +) -> go.Figure: + """Apply consistent formatting to a Plotly figure. + + Args: + fig: The Plotly figure to format + title: Optional title for the chart + x_title: Optional x-axis title + y_title: Optional y-axis title + colour_scheme: Colour scheme name (teal, blue, gray) + show_grid: Whether to show gridlines + show_legend: Whether to show the legend + height: Optional figure height in pixels + width: Optional figure width in pixels + + Returns: + The formatted figure + """ + + colours = COLOUR_SCHEMES.get(colour_scheme, COLOUR_SCHEMES["teal"]) + + # Update traces with colour scheme + for i, trace in enumerate(fig.data): + if hasattr(trace, "marker"): + trace.marker.color = DEFAULT_COLOURS[i % len(DEFAULT_COLOURS)] + if hasattr(trace, "line"): + trace.line.color = DEFAULT_COLOURS[i % len(DEFAULT_COLOURS)] + trace.line.width = 2 + + # Base layout settings + layout_updates = { + "font": { + "family": "Roboto, sans-serif", + "size": 14, + "color": COLOUR_SCHEMES["gray"]["darkest"], + }, + "plot_bgcolor": "white", + "paper_bgcolor": "white", + "showlegend": show_legend, + "hovermode": "x unified", + "hoverlabel": { + "bgcolor": "white", + "font": {"family": "Roboto Mono, monospace", "size": 12}, + "bordercolor": colours["light"], + }, + } + + # Add title if provided + if title: + layout_updates["title"] = { + "text": title, + "font": { + "family": "Roboto, sans-serif", + "size": 20, + "color": COLOUR_SCHEMES["gray"]["darkest"], + "weight": 500, + }, + } + + # Configure axes + axis_config = { + "showgrid": show_grid, + "gridcolor": COLOUR_SCHEMES["gray"]["light"], + "gridwidth": 1, + "zeroline": True, + "zerolinecolor": COLOUR_SCHEMES["gray"]["lighter"], + "zerolinewidth": 1, + "tickfont": { + "family": "Roboto Mono, monospace", + "size": 11, + "color": COLOUR_SCHEMES["gray"]["primary"], + }, + "titlefont": { + "family": "Roboto, sans-serif", + "size": 14, + "color": COLOUR_SCHEMES["gray"]["dark"], + }, + "linecolor": COLOUR_SCHEMES["gray"]["light"], + "linewidth": 1, + "showline": True, + "mirror": False, + } + + layout_updates["xaxis"] = axis_config.copy() + layout_updates["yaxis"] = axis_config.copy() + + if x_title: + layout_updates["xaxis"]["title"] = x_title + if y_title: + layout_updates["yaxis"]["title"] = y_title + + layout_updates["showlegend"] = len(fig.data) > 1 and show_legend + + # Set dimensions if provided + if height: + layout_updates["height"] = height + if width: + layout_updates["width"] = width + + fig.update_layout(**layout_updates) + + fig.update_xaxes(title_font_color=COLOUR_SCHEMES["gray"]["primary"]) + fig.update_yaxes(title_font_color=COLOUR_SCHEMES["gray"]["primary"]) + + # Add text annotations to bars in bar charts + if any(isinstance(trace, go.Bar) for trace in fig.data): + for trace in fig.data: + if isinstance(trace, go.Bar): + trace.texttemplate = "%{y:,.0f}" + trace.textposition = "outside" + trace.textfont = { + "family": "Roboto Mono, monospace", + "size": 11, + "color": 
COLOUR_SCHEMES["gray"]["primary"], + } + + return fig + + +def create_bar_chart( + data: dict[str, list], + x: str, + y: str, + title: str | None = None, + colour_scheme: str = "teal", + **kwargs, +) -> go.Figure: + """Create a formatted bar chart. + + Args: + data: Dictionary with data for the chart + x: Column name for x-axis + y: Column name for y-axis + title: Optional chart title + colour_scheme: Colour scheme to use + **kwargs: Additional arguments for format_figure + + Returns: + Formatted bar chart figure + """ + fig = go.Figure( + data=[ + go.Bar( + x=data[x], + y=data[y], + marker_color=COLOUR_SCHEMES[colour_scheme]["primary"], + marker_line_color=COLOUR_SCHEMES[colour_scheme]["dark"], + marker_line_width=1, + hovertemplate=f"{x}: " + + "%{x}
" + + f"{y}: " + + "%{y:,.0f}", + ) + ] + ) + + return format_figure( + fig, + title=title, + x_title=x, + y_title=y, + colour_scheme=colour_scheme, + **kwargs, + ) + + +def create_line_chart( + data: dict[str, list], + x: str, + y: str | list[str], + title: str | None = None, + colour_scheme: str = "teal", + **kwargs, +) -> go.Figure: + """Create a formatted line chart. + + Args: + data: Dictionary with data for the chart + x: Column name for x-axis + y: Column name(s) for y-axis (can be a list for multiple lines) + title: Optional chart title + colour_scheme: Colour scheme to use + **kwargs: Additional arguments for format_figure + + Returns: + Formatted line chart figure + """ + traces = [] + y_columns = y if isinstance(y, list) else [y] + + for i, y_col in enumerate(y_columns): + traces.append( + go.Scatter( + x=data[x], + y=data[y_col], + mode="lines+markers", + name=y_col, + line=dict( + color=DEFAULT_COLOURS[i % len(DEFAULT_COLOURS)], width=2 + ), + marker=dict(size=6), + hovertemplate=f"{y_col}: " + "%{y:,.0f}", + ) + ) + + fig = go.Figure(data=traces) + + return format_figure( + fig, + title=title, + x_title=x, + y_title=y_columns[0] if len(y_columns) == 1 else None, + colour_scheme=colour_scheme, + **kwargs, + ) diff --git a/build/lib/policyengine/utils/compress.py b/build/lib/policyengine/utils/compress.py new file mode 100644 index 00000000..19180e2a --- /dev/null +++ b/build/lib/policyengine/utils/compress.py @@ -0,0 +1,20 @@ +import pickle +from typing import Any + +import blosc + + +def compress_data(data: Any) -> bytes: + """Compress data using blosc after pickling.""" + pickled_data = pickle.dumps(data) + compressed_data = blosc.compress( + pickled_data, typesize=8, cname="zstd", clevel=9, shuffle=blosc.SHUFFLE + ) + return compressed_data + + +def decompress_data(compressed_data: bytes) -> Any: + """Decompress data using blosc and then unpickle.""" + decompressed_data = blosc.decompress(compressed_data) + data = pickle.loads(decompressed_data) + return data diff --git a/build/lib/policyengine/utils/datasets.py b/build/lib/policyengine/utils/datasets.py new file mode 100644 index 00000000..02090e11 --- /dev/null +++ b/build/lib/policyengine/utils/datasets.py @@ -0,0 +1,71 @@ +import pandas as pd + +from policyengine.models import Dataset + + +def create_uk_dataset( + dataset: str = "enhanced_frs_2023_24.h5", + year: int = 2029, +): + from policyengine_uk import Microsimulation + + from policyengine.models.policyengine_uk import policyengine_uk_model + + sim = Microsimulation( + dataset="hf://policyengine/policyengine-uk-data/" + dataset + ) + sim.default_calculation_period = year + + tables = { + "person": pd.DataFrame(sim.dataset[year].person), + "benunit": pd.DataFrame(sim.dataset[year].benunit), + "household": pd.DataFrame(sim.dataset[year].household), + } + + return Dataset( + id="uk", + name="UK", + description="A representative dataset for the UK, based on the Family Resources Survey.", + year=year, + model=policyengine_uk_model, + data=tables, + ) + + +def create_us_dataset( + dataset: str = "enhanced_cps_2024.h5", + year: int = 2024, +): + from policyengine_us import Microsimulation + + from policyengine.models.policyengine_us import policyengine_us_model + + sim = Microsimulation( + dataset="hf://policyengine/policyengine-us-data/" + dataset + ) + sim.default_calculation_period = year + + known_variables = sim.input_variables + + tables = { + "person": pd.DataFrame(), + "marital_unit": pd.DataFrame(), + "tax_unit": pd.DataFrame(), + "spm_unit": pd.DataFrame(), + 
"family": pd.DataFrame(), + "household": pd.DataFrame(), + } + + for variable in known_variables: + entity = sim.tax_benefit_system.variables[variable].entity.key + if variable in sim.tax_benefit_system.variables: + tables[entity][variable] = sim.calculate(variable) + + return Dataset( + id="us", + name="US", + description="A representative dataset for the US, based on the Current Population Survey.", + year=year, + model=policyengine_us_model, + data=tables, + ) diff --git a/docs/quickstart.ipynb b/docs/quickstart.ipynb index 46c7bdfd..f34568f8 100644 --- a/docs/quickstart.ipynb +++ b/docs/quickstart.ipynb @@ -66,12 +66,12 @@ "£500,000" ], "y": [ - 6604006.784030474, - 10307871.58979292, - 7152632.348702732, - 4284865.771267385, - 1718930.0846310211, - 1320096.830079406, + 6598054.774273342, + 10307676.305816924, + 7178099.41165927, + 4267341.442682308, + 1717525.567486669, + 1319705.9065854247, 326077.61111739336, 187608.23132836912, 63106.63353048405, @@ -1110,12 +1110,12 @@ "£500,000" ], "y": [ - 6604006.784030474, - 10307871.58979292, - 7152632.348702732, - 4284865.771267385, - 1718930.0846310211, - 1320096.830079406, + 6598054.774273342, + 10307676.305816924, + 7178099.41165927, + 4267341.442682308, + 1717525.567486669, + 1319705.9065854247, 326077.61111739336, 187608.23132836912, 63106.63353048405, @@ -1148,11 +1148,11 @@ "£500,000" ], "y": [ - 6148681.728599603, - 10315633.853035238, - 6911688.091380573, - 4489387.35271926, - 2005084.3325594077, + 6142729.718842472, + 10318841.418840772, + 6940624.716328709, + 4465347.839217288, + 2002931.6650648424, 1472352.5016867262, 341808.24952948757, 218180.35939976107, @@ -2224,7 +2224,7 @@ { "data": { "text/plain": [ - "Policy(id='9e558aa7-392f-4203-bf7d-769929e4f01e', name='Increase personal allowance to £20,000', description='A policy to increase the personal allowance for income tax to £20,000.', parameter_values=[ParameterValue(id='722765c7-5a97-4726-8156-a8f8fb26cb09', parameter=Parameter(id='gov.hmrc.income_tax.allowances.personal_allowance.amount', description=None, data_type=None, model=Model(id='policyengine_uk', name='PolicyEngine UK', description=\"PolicyEngine's open-source tax-benefit microsimulation model.\", simulation_function=), label=None, unit=None), value=20000, start_date=datetime.datetime(2029, 1, 1, 0, 0), end_date=None)], simulation_modifier=None, created_at=datetime.datetime(2025, 9, 28, 23, 42, 32, 430810), updated_at=datetime.datetime(2025, 9, 28, 23, 42, 32, 430842))" + "Policy(id='118195fb-4485-4616-bc9d-ed5ea540c300', name='Increase personal allowance to £20,000', description='A policy to increase the personal allowance for income tax to £20,000.', parameter_values=[ParameterValue(id='a8cfa8e2-f339-4896-a90f-ef470e365526', parameter=Parameter(id='gov.hmrc.income_tax.allowances.personal_allowance.amount', description=None, data_type=None, model=Model(id='policyengine_uk', name='PolicyEngine UK', description=\"PolicyEngine's open-source tax-benefit microsimulation model.\", simulation_function=), label=None, unit=None), value=20000, start_date=datetime.datetime(2029, 1, 1, 0, 0), end_date=None)], simulation_modifier=None, created_at=datetime.datetime(2025, 9, 29, 18, 21, 46, 643799), updated_at=datetime.datetime(2025, 9, 29, 18, 21, 46, 643805))" ] }, "execution_count": 5, diff --git a/src/policyengine/database/__init__.py b/src/policyengine/database/__init__.py index 5efcae3a..88e1a21b 100644 --- a/src/policyengine/database/__init__.py +++ b/src/policyengine/database/__init__.py @@ -28,6 +28,7 @@ from 
.report_table import ReportTable, report_table_link from .report_element_table import ReportElementTable, report_element_table_link from .aggregate import AggregateTable, aggregate_table_link +from .aggregate_change import AggregateChangeTable, aggregate_change_table_link __all__ = [ "Database", @@ -47,6 +48,7 @@ "ReportTable", "ReportElementTable", "AggregateTable", + "AggregateChangeTable", # Links "model_table_link", "model_version_table_link", @@ -62,4 +64,5 @@ "report_table_link", "report_element_table_link", "aggregate_table_link", + "aggregate_change_table_link", ] diff --git a/src/policyengine/database/aggregate.py b/src/policyengine/database/aggregate.py index 44c8aacd..c192605a 100644 --- a/src/policyengine/database/aggregate.py +++ b/src/policyengine/database/aggregate.py @@ -25,6 +25,9 @@ class AggregateTable(SQLModel, table=True): filter_variable_value: str | None = None filter_variable_leq: float | None = None filter_variable_geq: float | None = None + filter_variable_quantile_leq: float | None = None + filter_variable_quantile_geq: float | None = None + filter_variable_quantile_value: str | None = None aggregate_function: str reportelement_id: str | None = None value: float | None = None @@ -53,6 +56,9 @@ def convert_from_model(cls, model: Aggregate, database: "Database" = None) -> "A filter_variable_value=model.filter_variable_value, filter_variable_leq=model.filter_variable_leq, filter_variable_geq=model.filter_variable_geq, + filter_variable_quantile_leq=model.filter_variable_quantile_leq, + filter_variable_quantile_geq=model.filter_variable_quantile_geq, + filter_variable_quantile_value=model.filter_variable_quantile_value, aggregate_function=model.aggregate_function, reportelement_id=model.reportelement_id, value=model.value, @@ -89,6 +95,9 @@ def convert_to_model(self, database: "Database" = None) -> Aggregate: filter_variable_value=self.filter_variable_value, filter_variable_leq=self.filter_variable_leq, filter_variable_geq=self.filter_variable_geq, + filter_variable_quantile_leq=self.filter_variable_quantile_leq, + filter_variable_quantile_geq=self.filter_variable_quantile_geq, + filter_variable_quantile_value=self.filter_variable_quantile_value, aggregate_function=self.aggregate_function, reportelement_id=self.reportelement_id, value=self.value, diff --git a/src/policyengine/database/aggregate_change.py b/src/policyengine/database/aggregate_change.py new file mode 100644 index 00000000..e0a90ae4 --- /dev/null +++ b/src/policyengine/database/aggregate_change.py @@ -0,0 +1,128 @@ +from typing import TYPE_CHECKING +from uuid import uuid4 + +from sqlmodel import Field, SQLModel + +from policyengine.database.link import TableLink +from policyengine.models.aggregate_change import AggregateChange + +if TYPE_CHECKING: + from .database import Database + + +class AggregateChangeTable(SQLModel, table=True): + __tablename__ = "aggregate_changes" + + id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True) + baseline_simulation_id: str = Field( + foreign_key="simulations.id", ondelete="CASCADE" + ) + comparison_simulation_id: str = Field( + foreign_key="simulations.id", ondelete="CASCADE" + ) + entity: str + variable_name: str + year: int | None = None + filter_variable_name: str | None = None + filter_variable_value: str | None = None + filter_variable_leq: float | None = None + filter_variable_geq: float | None = None + filter_variable_quantile_leq: float | None = None + filter_variable_quantile_geq: float | None = None + filter_variable_quantile_value: str | 
None = None + aggregate_function: str + reportelement_id: str | None = None + + baseline_value: float | None = None + comparison_value: float | None = None + change: float | None = None + relative_change: float | None = None + + @classmethod + def convert_from_model(cls, model: AggregateChange, database: "Database" = None) -> "AggregateChangeTable": + """Convert an AggregateChange instance to an AggregateChangeTable instance. + + Args: + model: The AggregateChange instance to convert + database: The database instance for persisting the simulations if needed + + Returns: + An AggregateChangeTable instance + """ + return cls( + id=model.id, + baseline_simulation_id=model.baseline_simulation.id if model.baseline_simulation else None, + comparison_simulation_id=model.comparison_simulation.id if model.comparison_simulation else None, + entity=model.entity, + variable_name=model.variable_name, + year=model.year, + filter_variable_name=model.filter_variable_name, + filter_variable_value=model.filter_variable_value, + filter_variable_leq=model.filter_variable_leq, + filter_variable_geq=model.filter_variable_geq, + filter_variable_quantile_leq=model.filter_variable_quantile_leq, + filter_variable_quantile_geq=model.filter_variable_quantile_geq, + filter_variable_quantile_value=model.filter_variable_quantile_value, + aggregate_function=model.aggregate_function, + reportelement_id=model.reportelement_id, + baseline_value=model.baseline_value, + comparison_value=model.comparison_value, + change=model.change, + relative_change=model.relative_change, + ) + + def convert_to_model(self, database: "Database" = None) -> AggregateChange: + """Convert this AggregateChangeTable instance to an AggregateChange instance. + + Args: + database: The database instance for resolving simulation foreign keys + + Returns: + An AggregateChange instance + """ + from .simulation_table import SimulationTable + from sqlmodel import select + + # Resolve the simulation foreign keys + baseline_simulation = None + comparison_simulation = None + + if database: + if self.baseline_simulation_id: + sim_table = database.session.exec( + select(SimulationTable).where(SimulationTable.id == self.baseline_simulation_id) + ).first() + if sim_table: + baseline_simulation = sim_table.convert_to_model(database) + + if self.comparison_simulation_id: + sim_table = database.session.exec( + select(SimulationTable).where(SimulationTable.id == self.comparison_simulation_id) + ).first() + if sim_table: + comparison_simulation = sim_table.convert_to_model(database) + + return AggregateChange( + id=self.id, + baseline_simulation=baseline_simulation, + comparison_simulation=comparison_simulation, + entity=self.entity, + variable_name=self.variable_name, + year=self.year, + filter_variable_name=self.filter_variable_name, + filter_variable_value=self.filter_variable_value, + filter_variable_leq=self.filter_variable_leq, + filter_variable_geq=self.filter_variable_geq, + aggregate_function=self.aggregate_function, + reportelement_id=self.reportelement_id, + baseline_value=self.baseline_value, + comparison_value=self.comparison_value, + change=self.change, + relative_change=self.relative_change, + ) + + +aggregate_change_table_link = TableLink( + model_cls=AggregateChange, + table_cls=AggregateChangeTable, +) \ No newline at end of file diff --git a/src/policyengine/database/report_element_table.py b/src/policyengine/database/report_element_table.py index 86e57f9b..cc69e83e 100644 --- a/src/policyengine/database/report_element_table.py +++ 
b/src/policyengine/database/report_element_table.py @@ -22,7 +22,7 @@ class ReportElementTable(SQLModel, table=True, extend_existing=True): type: str = Field(nullable=False) # "chart" or "markdown" # Data source - data_table: str | None = Field(default=None) # "aggregates" + data_table: str | None = Field(default=None) # "aggregates" or "aggregate_changes" # Chart configuration chart_type: str | None = Field( diff --git a/src/policyengine/models/__init__.py b/src/policyengine/models/__init__.py index 652d46cd..de5fd8c9 100644 --- a/src/policyengine/models/__init__.py +++ b/src/policyengine/models/__init__.py @@ -1,5 +1,6 @@ from .aggregate import Aggregate as Aggregate from .aggregate import AggregateType as AggregateType +from .aggregate_change import AggregateChange as AggregateChange from .baseline_parameter_value import ( BaselineParameterValue as BaselineParameterValue, ) @@ -31,6 +32,8 @@ # Rebuild models to handle circular references from .aggregate import Aggregate +from .aggregate_change import AggregateChange from .simulation import Simulation Aggregate.model_rebuild() +AggregateChange.model_rebuild() Simulation.model_rebuild() diff --git a/src/policyengine/models/aggregate.py b/src/policyengine/models/aggregate.py index 5ae67569..031cad87 100644 --- a/src/policyengine/models/aggregate.py +++ b/src/policyengine/models/aggregate.py @@ -13,6 +13,7 @@ class AggregateType(str, Enum): SUM = "sum" MEAN = "mean" + MEDIAN = "median" COUNT = "count" @@ -27,7 +28,7 @@ class Aggregate(BaseModel): filter_variable_leq: float | None = None filter_variable_geq: float | None = None aggregate_function: Literal[ - AggregateType.SUM, AggregateType.MEAN, AggregateType.COUNT + AggregateType.SUM, AggregateType.MEAN, AggregateType.MEDIAN, AggregateType.COUNT ] reportelement_id: str | None = None @@ -121,6 +122,8 @@ def _process_simulation_aggregates( agg.value = float(df[agg.variable_name].sum()) elif agg.aggregate_function == AggregateType.MEAN: agg.value = float(df[agg.variable_name].mean()) + elif agg.aggregate_function == AggregateType.MEDIAN: + agg.value = float(df[agg.variable_name].median()) elif agg.aggregate_function == AggregateType.COUNT: agg.value = float((df[agg.variable_name] > 0).sum()) diff --git a/src/policyengine/models/aggregate_change.py b/src/policyengine/models/aggregate_change.py new file mode 100644 index 00000000..e0a400df --- /dev/null +++ b/src/policyengine/models/aggregate_change.py @@ -0,0 +1,143 @@ +from enum import Enum +from typing import TYPE_CHECKING, Literal +from uuid import uuid4 + +import pandas as pd +from microdf import MicroDataFrame +from pydantic import BaseModel, Field + +if TYPE_CHECKING: + from policyengine.models import Simulation + + +class AggregateType(str, Enum): + SUM = "sum" + MEAN = "mean" + MEDIAN = "median" + COUNT = "count" + + +class AggregateChange(BaseModel): + id: str = Field(default_factory=lambda: str(uuid4())) + baseline_simulation: "Simulation | None" = None + comparison_simulation: "Simulation | None" = None + entity: str + variable_name: str + year: int | None = None + filter_variable_name: str | None = None + filter_variable_value: str | None = None + filter_variable_leq: float | None = None + filter_variable_geq: float | None = None + aggregate_function: Literal[ + AggregateType.SUM, AggregateType.MEAN, AggregateType.MEDIAN, AggregateType.COUNT + ] + reportelement_id: str | None = None + + baseline_value: float | None = None + comparison_value: float | None = None + change: float | None = None + relative_change: float | None = None + 
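+    # Illustrative example (hypothetical values, not part of the schema): after
+    # run(), an AggregateChange with baseline_value=100.0 and
+    # comparison_value=110.0 ends up with change=10.0 and
+    # relative_change=0.1 (10.0 / abs(100.0)).
+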
+ @staticmethod + def run(aggregate_changes: list["AggregateChange"]) -> list["AggregateChange"]: + """Process aggregate changes, handling multiple simulation pairs.""" + results = [] + + for agg_change in aggregate_changes: + if agg_change.baseline_simulation is None: + raise ValueError("AggregateChange has no baseline simulation attached") + if agg_change.comparison_simulation is None: + raise ValueError("AggregateChange has no comparison simulation attached") + + # Compute baseline value + baseline_value = AggregateChange._compute_single_aggregate( + agg_change, agg_change.baseline_simulation + ) + + # Compute comparison value + comparison_value = AggregateChange._compute_single_aggregate( + agg_change, agg_change.comparison_simulation + ) + + # Compute changes + agg_change.baseline_value = baseline_value + agg_change.comparison_value = comparison_value + agg_change.change = comparison_value - baseline_value + + # Compute relative change (avoiding division by zero) + if baseline_value != 0: + agg_change.relative_change = (comparison_value - baseline_value) / abs(baseline_value) + else: + agg_change.relative_change = None if comparison_value == 0 else float('inf') + + results.append(agg_change) + + return results + + @staticmethod + def _compute_single_aggregate( + agg_change: "AggregateChange", simulation: "Simulation" + ) -> float: + """Compute aggregate value for a single simulation.""" + tables = simulation.result + # Copy tables to ensure we don't modify original dataframes + tables = {k: v.copy() for k, v in tables.items()} + + for table in tables: + tables[table] = pd.DataFrame(tables[table]) + weight_col = f"{table}_weight" + if weight_col in tables[table].columns: + tables[table] = MicroDataFrame( + tables[table], weights=weight_col + ) + + if agg_change.entity not in tables: + raise ValueError( + f"Entity {agg_change.entity} not found in simulation results" + ) + + table = tables[agg_change.entity] + + if agg_change.variable_name not in table.columns: + raise ValueError( + f"Variable {agg_change.variable_name} not found in entity {agg_change.entity}" + ) + + df = table + + if agg_change.year is None: + agg_change.year = simulation.dataset.year + + # Apply filters + if agg_change.filter_variable_name is not None: + if agg_change.filter_variable_name not in df.columns: + raise ValueError( + f"Filter variable {agg_change.filter_variable_name} not found in entity {agg_change.entity}" + ) + if agg_change.filter_variable_value is not None: + df = df[ + df[agg_change.filter_variable_name] + == agg_change.filter_variable_value + ] + if agg_change.filter_variable_leq is not None: + df = df[ + df[agg_change.filter_variable_name] <= agg_change.filter_variable_leq + ] + if agg_change.filter_variable_geq is not None: + df = df[ + df[agg_change.filter_variable_name] >= agg_change.filter_variable_geq + ] + + # Compute aggregate + if agg_change.aggregate_function == AggregateType.SUM: + value = float(df[agg_change.variable_name].sum()) + elif agg_change.aggregate_function == AggregateType.MEAN: + value = float(df[agg_change.variable_name].mean()) + elif agg_change.aggregate_function == AggregateType.MEDIAN: + value = float(df[agg_change.variable_name].median()) + elif agg_change.aggregate_function == AggregateType.COUNT: + value = float((df[agg_change.variable_name] > 0).sum()) + else: + raise ValueError(f"Unknown aggregate function: {agg_change.aggregate_function}") + + return value \ No newline at end of file diff --git a/src/policyengine/models/report_element.py 
b/src/policyengine/models/report_element.py index ac7fcfcb..180ec26c 100644 --- a/src/policyengine/models/report_element.py +++ b/src/policyengine/models/report_element.py @@ -11,7 +11,7 @@ class ReportElement(BaseModel): type: Literal["chart", "markdown"] # Data source - data_table: Literal["aggregates"] | None = None # Which table to pull from + data_table: Literal["aggregates", "aggregate_changes"] | None = None # Which table to pull from # Chart configuration chart_type: ( From f49e8351a9f4250b5733eee70f9c5e94b85c38d2 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Mon, 29 Sep 2025 18:24:36 +0100 Subject: [PATCH 04/35] Update --- src/policyengine/database/aggregate_change.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/policyengine/database/aggregate_change.py b/src/policyengine/database/aggregate_change.py index e0a90ae4..47fbda05 100644 --- a/src/policyengine/database/aggregate_change.py +++ b/src/policyengine/database/aggregate_change.py @@ -113,6 +113,9 @@ def convert_to_model(self, database: "Database" = None) -> AggregateChange: filter_variable_value=self.filter_variable_value, filter_variable_leq=self.filter_variable_leq, filter_variable_geq=self.filter_variable_geq, + filter_variable_quantile_leq=self.filter_variable_quantile_leq, + filter_variable_quantile_geq=self.filter_variable_quantile_geq, + filter_variable_quantile_value=self.filter_variable_quantile_value, aggregate_function=self.aggregate_function, reportelement_id=self.reportelement_id, baseline_value=self.baseline_value, From 3fa2d49f479cb93ce20817ae022eba0f32b39ad4 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Mon, 29 Sep 2025 23:45:37 +0100 Subject: [PATCH 05/35] Update --- src/policyengine/models/aggregate.py | 83 +++- src/policyengine/models/aggregate_change.py | 486 ++++++++++++++++++-- 2 files changed, 511 insertions(+), 58 deletions(-) diff --git a/src/policyengine/models/aggregate.py b/src/policyengine/models/aggregate.py index 031cad87..bb580c94 100644 --- a/src/policyengine/models/aggregate.py +++ b/src/policyengine/models/aggregate.py @@ -20,13 +20,16 @@ class AggregateType(str, Enum): class Aggregate(BaseModel): id: str = Field(default_factory=lambda: str(uuid4())) simulation: "Simulation | None" = None - entity: str + entity: str | None = None variable_name: str year: int | None = None filter_variable_name: str | None = None filter_variable_value: str | None = None filter_variable_leq: float | None = None filter_variable_geq: float | None = None + filter_variable_quantile_leq: float | None = None + filter_variable_quantile_geq: float | None = None + filter_variable_quantile_value: str | None = None aggregate_function: Literal[ AggregateType.SUM, AggregateType.MEAN, AggregateType.MEDIAN, AggregateType.COUNT ] @@ -34,6 +37,29 @@ class Aggregate(BaseModel): value: float | None = None + @staticmethod + def _infer_entity(variable_name: str, tables: dict) -> str: + """Infer entity from variable name by checking which table contains it.""" + for entity, table in tables.items(): + if variable_name in table.columns: + return entity + raise ValueError(f"Variable {variable_name} not found in any entity table") + + @staticmethod + def _get_entity_link_columns() -> dict: + """Return mapping of entity relationships for common PolicyEngine models.""" + return { + # person -> group entity links + "person": { + "benunit": "person_benunit_id", + "household": "person_household_id", + "family": "person_family_id", + "tax_unit": "person_tax_unit_id", + "spm_unit": "person_spm_unit_id", + }, + # Group 
entities don't have direct upward links typically + } + @staticmethod def run(aggregates: list["Aggregate"]) -> list["Aggregate"]: """Process aggregates, handling multiple simulations if necessary.""" @@ -83,6 +109,10 @@ def _process_simulation_aggregates( ) for agg in aggregates: + # Infer entity if not provided + if agg.entity is None: + agg.entity = Aggregate._infer_entity(agg.variable_name, tables) + if agg.entity not in tables: raise ValueError( f"Entity {agg.entity} not found in simulation results" @@ -104,6 +134,8 @@ def _process_simulation_aggregates( raise ValueError( f"Filter variable {agg.filter_variable_name} not found in entity {agg.entity}" ) + + # Apply value/range filters if agg.filter_variable_value is not None: df = df[ df[agg.filter_variable_name] @@ -118,14 +150,47 @@ def _process_simulation_aggregates( df[agg.filter_variable_name] >= agg.filter_variable_geq ] - if agg.aggregate_function == AggregateType.SUM: - agg.value = float(df[agg.variable_name].sum()) - elif agg.aggregate_function == AggregateType.MEAN: - agg.value = float(df[agg.variable_name].mean()) - elif agg.aggregate_function == AggregateType.MEDIAN: - agg.value = float(df[agg.variable_name].median()) - elif agg.aggregate_function == AggregateType.COUNT: - agg.value = float((df[agg.variable_name] > 0).sum()) + # Apply quantile filters if specified + if any([agg.filter_variable_quantile_leq, + agg.filter_variable_quantile_geq, agg.filter_variable_quantile_value]): + + if agg.filter_variable_quantile_leq is not None: + # Filter to values <= specified quantile + threshold = df[agg.filter_variable_name].quantile(agg.filter_variable_quantile_leq) + df = df[df[agg.filter_variable_name] <= threshold] + + if agg.filter_variable_quantile_geq is not None: + # Filter to values >= specified quantile + threshold = df[agg.filter_variable_name].quantile(agg.filter_variable_quantile_geq) + df = df[df[agg.filter_variable_name] >= threshold] + + if agg.filter_variable_quantile_value is not None: + # Parse quantile value like "top_10%" or "bottom_20%" + if "top" in agg.filter_variable_quantile_value.lower(): + pct = float(agg.filter_variable_quantile_value.lower().replace("top_", "").replace("%", "")) / 100 + threshold = df[agg.filter_variable_name].quantile(1 - pct) + df = df[df[agg.filter_variable_name] >= threshold] + elif "bottom" in agg.filter_variable_quantile_value.lower(): + pct = float(agg.filter_variable_quantile_value.lower().replace("bottom_", "").replace("%", "")) / 100 + threshold = df[agg.filter_variable_name].quantile(pct) + df = df[df[agg.filter_variable_name] <= threshold] + + # Check if we have any data left after filtering + if len(df) == 0: + agg.value = 0.0 + else: + try: + if agg.aggregate_function == AggregateType.SUM: + agg.value = float(df[agg.variable_name].sum()) + elif agg.aggregate_function == AggregateType.MEAN: + agg.value = float(df[agg.variable_name].mean()) + elif agg.aggregate_function == AggregateType.MEDIAN: + agg.value = float(df[agg.variable_name].median()) + elif agg.aggregate_function == AggregateType.COUNT: + agg.value = float((df[agg.variable_name] > 0).sum()) + except (ZeroDivisionError, ValueError): + # Handle cases where weights sum to zero + agg.value = 0.0 results.append(agg) diff --git a/src/policyengine/models/aggregate_change.py b/src/policyengine/models/aggregate_change.py index e0a400df..6c0b9ccb 100644 --- a/src/policyengine/models/aggregate_change.py +++ b/src/policyengine/models/aggregate_change.py @@ -5,6 +5,7 @@ import pandas as pd from microdf import MicroDataFrame 
from pydantic import BaseModel, Field +import time if TYPE_CHECKING: from policyengine.models import Simulation @@ -21,13 +22,16 @@ class AggregateChange(BaseModel): id: str = Field(default_factory=lambda: str(uuid4())) baseline_simulation: "Simulation | None" = None comparison_simulation: "Simulation | None" = None - entity: str + entity: str | None = None variable_name: str year: int | None = None filter_variable_name: str | None = None filter_variable_value: str | None = None filter_variable_leq: float | None = None filter_variable_geq: float | None = None + filter_variable_quantile_leq: float | None = None + filter_variable_quantile_geq: float | None = None + filter_variable_quantile_value: str | None = None aggregate_function: Literal[ AggregateType.SUM, AggregateType.MEAN, AggregateType.MEDIAN, AggregateType.COUNT ] @@ -38,47 +42,436 @@ class AggregateChange(BaseModel): change: float | None = None relative_change: float | None = None + @staticmethod + def _infer_entity(variable_name: str, filter_variable_name: str | None, tables: dict) -> str: + """Infer entity from the target variable (not the filter variable). + + The entity represents what level we're aggregating at, determined by the target variable. + Filters can be cross-entity and will be mapped if needed. + """ + # Find entity of target variable + for entity, table in tables.items(): + if variable_name in table.columns: + return entity + + raise ValueError(f"Variable {variable_name} not found in any entity table") + + @staticmethod + def _get_entity_link_columns() -> dict: + """Return mapping of entity relationships for common PolicyEngine models.""" + return { + # person -> group entity links (copy values down) + "person": { + "benunit": "person_benunit_id", + "household": "person_household_id", + "family": "person_family_id", + "tax_unit": "person_tax_unit_id", + "spm_unit": "person_spm_unit_id", + }, + } + + @staticmethod + def _map_variable_across_entities( + df: pd.DataFrame, + variable_name: str, + source_entity: str, + target_entity: str, + tables: dict + ) -> pd.Series: + """Map a variable from source entity to target entity level.""" + links = AggregateChange._get_entity_link_columns() + + # Group to person: copy group values to persons using link column + if source_entity != "person" and target_entity == "person": + link_col = links.get("person", {}).get(source_entity) + if link_col is None: + raise ValueError(f"No known link from person to {source_entity}") + + if link_col not in tables["person"].columns: + raise ValueError(f"Link column {link_col} not found in person table") + + # Create mapping: group position (0-based index) -> value + # Most PolicyEngine models have entities numbered 0, 1, 2, ... 
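+            # Worked example (illustrative values): household values
+            # [10.0, 20.0] with person_household_id values [0, 0, 1] broadcast
+            # down to persons as [10.0, 10.0, 20.0]; out-of-range ids fall
+            # back to 0 below.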
+ group_values = df[variable_name].values + + # Map to person level using the link column + person_table = tables["person"] + person_group_ids = person_table[link_col].values + + # Map each person to their group's value + result = pd.Series([group_values[int(gid)] if int(gid) < len(group_values) else 0 + for gid in person_group_ids], index=person_table.index) + return result + + # Person to group: sum persons' values to group level + elif source_entity == "person" and target_entity != "person": + link_col = links.get("person", {}).get(target_entity) + if link_col is None: + raise ValueError(f"No known link from person to {target_entity}") + + if link_col not in df.columns: + raise ValueError(f"Link column {link_col} not found in person table") + + # Sum by group - need to align with group table length + grouped = df.groupby(link_col)[variable_name].sum() + + # Create a series aligned with the group table + group_table = tables[target_entity] + result = pd.Series([grouped.get(i, 0) for i in range(len(group_table))], + index=group_table.index) + return result + + # Group to group: try via person as intermediary + elif source_entity != "person" and target_entity != "person": + # Map source -> person -> target + person_values = AggregateChange._map_variable_across_entities( + df, variable_name, source_entity, "person", tables + ) + # Create temp dataframe with person values + temp_person_df = tables["person"].copy() + temp_person_df[variable_name] = person_values + + return AggregateChange._map_variable_across_entities( + temp_person_df, variable_name, "person", target_entity, tables + ) + + else: + # Same entity - shouldn't happen but return as-is + return df[variable_name] + + @staticmethod + def _prepare_tables(simulation: "Simulation") -> dict: + """Prepare dataframes from simulation result once.""" + tables = simulation.result + tables = {k: v.copy() for k, v in tables.items()} + + for table in tables: + tables[table] = pd.DataFrame(tables[table]) + weight_col = f"{table}_weight" + if weight_col in tables[table].columns: + tables[table] = MicroDataFrame( + tables[table], weights=weight_col + ) + + return tables + @staticmethod def run(aggregate_changes: list["AggregateChange"]) -> list["AggregateChange"]: - """Process aggregate changes, handling multiple simulation pairs.""" - results = [] + """Process aggregate changes, batching those with the same simulation pair.""" + start_time = time.time() + print(f"[PERFORMANCE] AggregateChange.run starting with {len(aggregate_changes)} items") + # Group aggregate changes by simulation pair for batch processing + from collections import defaultdict + grouped = defaultdict(list) for agg_change in aggregate_changes: if agg_change.baseline_simulation is None: raise ValueError("AggregateChange has no baseline simulation attached") if agg_change.comparison_simulation is None: raise ValueError("AggregateChange has no comparison simulation attached") - # Compute baseline value - baseline_value = AggregateChange._compute_single_aggregate( - agg_change, agg_change.baseline_simulation + key = (agg_change.baseline_simulation.id, agg_change.comparison_simulation.id) + grouped[key].append(agg_change) + + print(f"[PERFORMANCE] Grouped {len(aggregate_changes)} items into {len(grouped)} simulation pairs") + + results = [] + + for (baseline_id, comparison_id), group in grouped.items(): + group_start = time.time() + print(f"[PERFORMANCE] Processing batch of {len(group)} items for sim pair {baseline_id[:8]}...{comparison_id[:8]}") + + # Get simulation objects once 
for the group + baseline_sim = group[0].baseline_simulation + comparison_sim = group[0].comparison_simulation + + # Pre-compute simulation dataframes once per batch + baseline_tables = AggregateChange._prepare_tables(baseline_sim) + comparison_tables = AggregateChange._prepare_tables(comparison_sim) + + prep_time = time.time() + print(f"[PERFORMANCE] Table preparation took {prep_time - group_start:.3f} seconds") + + # Process each item in the group + for idx, agg_change in enumerate(group): + item_start = time.time() + + # Infer entity if not provided (use filter variable entity if available) + if agg_change.entity is None: + agg_change.entity = AggregateChange._infer_entity( + agg_change.variable_name, + agg_change.filter_variable_name, + baseline_tables + ) + + # Compute filter mask on baseline + filter_mask = AggregateChange._get_filter_mask_from_tables( + agg_change, baseline_tables + ) + + # Compute baseline value + baseline_value = AggregateChange._compute_single_aggregate_from_tables( + agg_change, baseline_tables, filter_mask + ) + + # Compute comparison value using same filter + comparison_value = AggregateChange._compute_single_aggregate_from_tables( + agg_change, comparison_tables, filter_mask + ) + + # Compute changes + agg_change.baseline_value = baseline_value + agg_change.comparison_value = comparison_value + agg_change.change = comparison_value - baseline_value + + # Compute relative change (avoiding division by zero) + if baseline_value != 0: + agg_change.relative_change = (comparison_value - baseline_value) / abs(baseline_value) + else: + agg_change.relative_change = None if comparison_value == 0 else float('inf') + + results.append(agg_change) + + group_time = time.time() + print(f"[PERFORMANCE] Batch processing took {group_time - group_start:.3f} seconds ({(group_time - group_start) / len(group):.3f}s per item)") + + total_time = time.time() + print(f"[PERFORMANCE] AggregateChange.run completed in {total_time - start_time:.2f} seconds") + return results + + @staticmethod + def _get_filter_mask_from_tables( + agg_change: "AggregateChange", tables: dict + ) -> pd.Series | None: + """Get filter mask from pre-prepared tables, handling cross-entity filters.""" + if agg_change.filter_variable_name is None: + return None + + if agg_change.entity not in tables: + raise ValueError( + f"Entity {agg_change.entity} not found in simulation results" + ) + + # Find which entity contains the filter variable + filter_entity = None + for entity, table in tables.items(): + if agg_change.filter_variable_name in table.columns: + filter_entity = entity + break + + if filter_entity is None: + raise ValueError( + f"Filter variable {agg_change.filter_variable_name} not found in any entity" ) - # Compute comparison value - comparison_value = AggregateChange._compute_single_aggregate( - agg_change, agg_change.comparison_simulation + # Get the dataframe for filtering + if filter_entity == agg_change.entity: + # Same entity - use directly + df = tables[agg_change.entity] + filter_series = df[agg_change.filter_variable_name] + else: + # Different entity - need to map filter variable to target entity + filter_df = tables[filter_entity] + mapped_filter = AggregateChange._map_variable_across_entities( + filter_df, + agg_change.filter_variable_name, + filter_entity, + agg_change.entity, + tables ) + df = tables[agg_change.entity] + filter_series = mapped_filter - # Compute changes - agg_change.baseline_value = baseline_value - agg_change.comparison_value = comparison_value - agg_change.change = 
comparison_value - baseline_value + mask = pd.Series([True] * len(df), index=df.index) - # Compute relative change (avoiding division by zero) - if baseline_value != 0: - agg_change.relative_change = (comparison_value - baseline_value) / abs(baseline_value) + if agg_change.filter_variable_value is not None: + mask &= filter_series == agg_change.filter_variable_value + + if agg_change.filter_variable_leq is not None: + mask &= filter_series <= agg_change.filter_variable_leq + + if agg_change.filter_variable_geq is not None: + mask &= filter_series >= agg_change.filter_variable_geq + + if any([agg_change.filter_variable_quantile_leq, + agg_change.filter_variable_quantile_geq, agg_change.filter_variable_quantile_value]): + + if agg_change.filter_variable_quantile_leq is not None: + threshold = filter_series.quantile(agg_change.filter_variable_quantile_leq) + mask &= filter_series <= threshold + + if agg_change.filter_variable_quantile_geq is not None: + threshold = filter_series.quantile(agg_change.filter_variable_quantile_geq) + mask &= filter_series >= threshold + + if agg_change.filter_variable_quantile_value is not None: + if "top" in agg_change.filter_variable_quantile_value.lower(): + pct = float(agg_change.filter_variable_quantile_value.lower().replace("top_", "").replace("%", "")) / 100 + threshold = filter_series.quantile(1 - pct) + mask &= filter_series >= threshold + elif "bottom" in agg_change.filter_variable_quantile_value.lower(): + pct = float(agg_change.filter_variable_quantile_value.lower().replace("bottom_", "").replace("%", "")) / 100 + threshold = filter_series.quantile(pct) + mask &= filter_series <= threshold + + return mask + + @staticmethod + def _compute_single_aggregate_from_tables( + agg_change: "AggregateChange", + tables: dict, + filter_mask: pd.Series | None = None + ) -> float: + """Compute aggregate value from pre-prepared tables.""" + if agg_change.entity not in tables: + raise ValueError( + f"Entity {agg_change.entity} not found in simulation results" + ) + + # Check if variable is in the target entity + target_entity = agg_change.entity + variable_entity = None + + # Find which entity contains the variable + for entity, table in tables.items(): + if agg_change.variable_name in table.columns: + variable_entity = entity + break + + if variable_entity is None: + raise ValueError( + f"Variable {agg_change.variable_name} not found in any entity" + ) + + # If variable is in a different entity than the filter, we need to map + if variable_entity != target_entity: + # Get the variable data from its native entity + source_table = tables[variable_entity] + + # Map it to the target entity level + try: + mapped_series = AggregateChange._map_variable_across_entities( + source_table, + agg_change.variable_name, + variable_entity, + target_entity, + tables + ) + # Create a temporary dataframe with the mapped variable + table = tables[target_entity].copy() + table[agg_change.variable_name] = mapped_series + except ValueError as e: + # If mapping fails, raise informative error + raise ValueError( + f"Variable {agg_change.variable_name} is in {variable_entity} entity, " + f"but filters are at {target_entity} level. 
Cannot map between these entities: {str(e)}" + ) + else: + table = tables[agg_change.entity] + + df = table + + if filter_mask is not None: + df = df[filter_mask] + + if len(df) == 0: + return 0.0 + + try: + if agg_change.aggregate_function == AggregateType.SUM: + value = float(df[agg_change.variable_name].sum()) + elif agg_change.aggregate_function == AggregateType.MEAN: + value = float(df[agg_change.variable_name].mean()) + elif agg_change.aggregate_function == AggregateType.MEDIAN: + value = float(df[agg_change.variable_name].median()) + elif agg_change.aggregate_function == AggregateType.COUNT: + value = float((df[agg_change.variable_name] > 0).sum()) else: - agg_change.relative_change = None if comparison_value == 0 else float('inf') + raise ValueError(f"Unknown aggregate function: {agg_change.aggregate_function}") + except (ZeroDivisionError, ValueError) as e: + return 0.0 - results.append(agg_change) + return value - return results + @staticmethod + def _get_filter_mask( + agg_change: "AggregateChange", simulation: "Simulation" + ) -> pd.Series | None: + """Get filter mask based on baseline simulation values.""" + if agg_change.filter_variable_name is None: + return None # No filtering needed + + tables = simulation.result + tables = {k: v.copy() for k, v in tables.items()} + + for table in tables: + tables[table] = pd.DataFrame(tables[table]) + weight_col = f"{table}_weight" + if weight_col in tables[table].columns: + tables[table] = MicroDataFrame( + tables[table], weights=weight_col + ) + + if agg_change.entity not in tables: + raise ValueError( + f"Entity {agg_change.entity} not found in simulation results" + ) + + df = tables[agg_change.entity] + + if agg_change.filter_variable_name not in df.columns: + raise ValueError( + f"Filter variable {agg_change.filter_variable_name} not found in entity {agg_change.entity}" + ) + + # Create filter mask based on baseline values + mask = pd.Series([True] * len(df), index=df.index) + + # Apply value/range filters + if agg_change.filter_variable_value is not None: + mask &= df[agg_change.filter_variable_name] == agg_change.filter_variable_value + + if agg_change.filter_variable_leq is not None: + mask &= df[agg_change.filter_variable_name] <= agg_change.filter_variable_leq + + if agg_change.filter_variable_geq is not None: + mask &= df[agg_change.filter_variable_name] >= agg_change.filter_variable_geq + + # Apply quantile filters if specified + if any([agg_change.filter_variable_quantile_leq, + agg_change.filter_variable_quantile_geq, agg_change.filter_variable_quantile_value]): + + if agg_change.filter_variable_quantile_leq is not None: + # Filter to values <= specified quantile + threshold = df[agg_change.filter_variable_name].quantile(agg_change.filter_variable_quantile_leq) + mask &= df[agg_change.filter_variable_name] <= threshold + + if agg_change.filter_variable_quantile_geq is not None: + # Filter to values >= specified quantile + threshold = df[agg_change.filter_variable_name].quantile(agg_change.filter_variable_quantile_geq) + mask &= df[agg_change.filter_variable_name] >= threshold + + if agg_change.filter_variable_quantile_value is not None: + # Parse quantile value like "top_10%" or "bottom_20%" + if "top" in agg_change.filter_variable_quantile_value.lower(): + pct = float(agg_change.filter_variable_quantile_value.lower().replace("top_", "").replace("%", "")) / 100 + threshold = df[agg_change.filter_variable_name].quantile(1 - pct) + mask &= df[agg_change.filter_variable_name] >= threshold + elif "bottom" in 
agg_change.filter_variable_quantile_value.lower(): + pct = float(agg_change.filter_variable_quantile_value.lower().replace("bottom_", "").replace("%", "")) / 100 + threshold = df[agg_change.filter_variable_name].quantile(pct) + mask &= df[agg_change.filter_variable_name] <= threshold + + return mask @staticmethod def _compute_single_aggregate( - agg_change: "AggregateChange", simulation: "Simulation" + agg_change: "AggregateChange", + simulation: "Simulation", + filter_mask: pd.Series | None = None ) -> float: """Compute aggregate value for a single simulation.""" + compute_start = time.time() tables = simulation.result # Copy tables to ensure we don't modify original dataframes tables = {k: v.copy() for k, v in tables.items()} @@ -108,36 +501,31 @@ def _compute_single_aggregate( if agg_change.year is None: agg_change.year = simulation.dataset.year - # Apply filters - if agg_change.filter_variable_name is not None: - if agg_change.filter_variable_name not in df.columns: - raise ValueError( - f"Filter variable {agg_change.filter_variable_name} not found in entity {agg_change.entity}" - ) - if agg_change.filter_variable_value is not None: - df = df[ - df[agg_change.filter_variable_name] - == agg_change.filter_variable_value - ] - if agg_change.filter_variable_leq is not None: - df = df[ - df[agg_change.filter_variable_name] <= agg_change.filter_variable_leq - ] - if agg_change.filter_variable_geq is not None: - df = df[ - df[agg_change.filter_variable_name] >= agg_change.filter_variable_geq - ] + # Apply the pre-computed filter mask if provided + # This ensures we're using the same subset of entities for both baseline and comparison + if filter_mask is not None: + df = df[filter_mask] + + # Check if we have any data left after filtering + if len(df) == 0: + # Return 0 for empty datasets + return 0.0 # Compute aggregate - if agg_change.aggregate_function == AggregateType.SUM: - value = float(df[agg_change.variable_name].sum()) - elif agg_change.aggregate_function == AggregateType.MEAN: - value = float(df[agg_change.variable_name].mean()) - elif agg_change.aggregate_function == AggregateType.MEDIAN: - value = float(df[agg_change.variable_name].median()) - elif agg_change.aggregate_function == AggregateType.COUNT: - value = float((df[agg_change.variable_name] > 0).sum()) - else: - raise ValueError(f"Unknown aggregate function: {agg_change.aggregate_function}") + try: + if agg_change.aggregate_function == AggregateType.SUM: + value = float(df[agg_change.variable_name].sum()) + elif agg_change.aggregate_function == AggregateType.MEAN: + value = float(df[agg_change.variable_name].mean()) + elif agg_change.aggregate_function == AggregateType.MEDIAN: + value = float(df[agg_change.variable_name].median()) + elif agg_change.aggregate_function == AggregateType.COUNT: + value = float((df[agg_change.variable_name] > 0).sum()) + else: + raise ValueError(f"Unknown aggregate function: {agg_change.aggregate_function}") + except (ZeroDivisionError, ValueError) as e: + # Handle cases where weights sum to zero or other computation errors + # Return 0 for these edge cases + return 0.0 return value \ No newline at end of file From 0a4b69f31da7aa4a8cf6f3980aba0f1dde28353c Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Thu, 2 Oct 2025 23:14:39 +0100 Subject: [PATCH 06/35] Add user associations --- src/policyengine/database/__init__.py | 18 ++++++++++ src/policyengine/database/database.py | 10 ++++++ .../database/user_dataset_table.py | 33 +++++++++++++++++++ .../database/user_dynamic_table.py | 33 
+++++++++++++++++++ .../database/user_policy_table.py | 33 +++++++++++++++++++ .../database/user_report_table.py | 33 +++++++++++++++++++ .../database/user_simulation_table.py | 33 +++++++++++++++++++ 7 files changed, 193 insertions(+) create mode 100644 src/policyengine/database/user_dataset_table.py create mode 100644 src/policyengine/database/user_dynamic_table.py create mode 100644 src/policyengine/database/user_policy_table.py create mode 100644 src/policyengine/database/user_report_table.py create mode 100644 src/policyengine/database/user_simulation_table.py diff --git a/src/policyengine/database/__init__.py b/src/policyengine/database/__init__.py index 88e1a21b..293a41b1 100644 --- a/src/policyengine/database/__init__.py +++ b/src/policyengine/database/__init__.py @@ -29,6 +29,12 @@ from .report_element_table import ReportElementTable, report_element_table_link from .aggregate import AggregateTable, aggregate_table_link from .aggregate_change import AggregateChangeTable, aggregate_change_table_link +from .user_table import UserTable, user_table_link +from .user_policy_table import UserPolicyTable, user_policy_table_link +from .user_dynamic_table import UserDynamicTable, user_dynamic_table_link +from .user_dataset_table import UserDatasetTable, user_dataset_table_link +from .user_simulation_table import UserSimulationTable, user_simulation_table_link +from .user_report_table import UserReportTable, user_report_table_link __all__ = [ "Database", @@ -49,6 +55,12 @@ "ReportElementTable", "AggregateTable", "AggregateChangeTable", + "UserTable", + "UserPolicyTable", + "UserDynamicTable", + "UserDatasetTable", + "UserSimulationTable", + "UserReportTable", # Links "model_table_link", "model_version_table_link", @@ -65,4 +77,10 @@ "report_element_table_link", "aggregate_table_link", "aggregate_change_table_link", + "user_table_link", + "user_policy_table_link", + "user_dynamic_table_link", + "user_dataset_table_link", + "user_simulation_table_link", + "user_report_table_link", ] diff --git a/src/policyengine/database/database.py b/src/policyengine/database/database.py index 2ae77e1c..dc5bb539 100644 --- a/src/policyengine/database/database.py +++ b/src/policyengine/database/database.py @@ -19,6 +19,11 @@ from .report_table import report_table_link from .simulation_table import simulation_table_link from .user_table import user_table_link +from .user_policy_table import user_policy_table_link +from .user_dynamic_table import user_dynamic_table_link +from .user_dataset_table import user_dataset_table_link +from .user_simulation_table import user_simulation_table_link +from .user_report_table import user_report_table_link from .versioned_dataset_table import versioned_dataset_table_link @@ -46,6 +51,11 @@ def __init__(self, url: str): simulation_table_link, aggregate_table_link, user_table_link, + user_policy_table_link, + user_dynamic_table_link, + user_dataset_table_link, + user_simulation_table_link, + user_report_table_link, report_table_link, report_element_table_link, ]: diff --git a/src/policyengine/database/user_dataset_table.py b/src/policyengine/database/user_dataset_table.py new file mode 100644 index 00000000..b2ff27b8 --- /dev/null +++ b/src/policyengine/database/user_dataset_table.py @@ -0,0 +1,33 @@ +from datetime import datetime +from typing import TYPE_CHECKING +from uuid import uuid4 + +from sqlmodel import Field, SQLModel +from pydantic import BaseModel + +from .link import TableLink + +if TYPE_CHECKING: + from .database import Database + + +class UserDatasetTable(SQLModel, 
table=True): + __tablename__ = "user_datasets" + + id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True) + user_id: str = Field(foreign_key="users.id", nullable=False) + dataset_id: str = Field(foreign_key="datasets.id", nullable=False) + custom_name: str | None = Field(default=None) + created_at: datetime = Field(default_factory=datetime.now) + updated_at: datetime = Field(default_factory=datetime.now) + + +# Create a dummy model class for the table link +class UserDataset(BaseModel): + pass + + +user_dataset_table_link = TableLink( + model_cls=UserDataset, + table_cls=UserDatasetTable, +) \ No newline at end of file diff --git a/src/policyengine/database/user_dynamic_table.py b/src/policyengine/database/user_dynamic_table.py new file mode 100644 index 00000000..e307c6c8 --- /dev/null +++ b/src/policyengine/database/user_dynamic_table.py @@ -0,0 +1,33 @@ +from datetime import datetime +from typing import TYPE_CHECKING +from uuid import uuid4 + +from sqlmodel import Field, SQLModel +from pydantic import BaseModel + +from .link import TableLink + +if TYPE_CHECKING: + from .database import Database + + +class UserDynamicTable(SQLModel, table=True): + __tablename__ = "user_dynamics" + + id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True) + user_id: str = Field(foreign_key="users.id", nullable=False) + dynamic_id: str = Field(foreign_key="dynamics.id", nullable=False) + custom_name: str | None = Field(default=None) + created_at: datetime = Field(default_factory=datetime.now) + updated_at: datetime = Field(default_factory=datetime.now) + + +# Create a dummy model class for the table link +class UserDynamic(BaseModel): + pass + + +user_dynamic_table_link = TableLink( + model_cls=UserDynamic, + table_cls=UserDynamicTable, +) \ No newline at end of file diff --git a/src/policyengine/database/user_policy_table.py b/src/policyengine/database/user_policy_table.py new file mode 100644 index 00000000..5fa14ed6 --- /dev/null +++ b/src/policyengine/database/user_policy_table.py @@ -0,0 +1,33 @@ +from datetime import datetime +from typing import TYPE_CHECKING +from uuid import uuid4 + +from sqlmodel import Field, SQLModel +from pydantic import BaseModel + +from .link import TableLink + +if TYPE_CHECKING: + from .database import Database + + +class UserPolicyTable(SQLModel, table=True): + __tablename__ = "user_policies" + + id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True) + user_id: str = Field(foreign_key="users.id", nullable=False) + policy_id: str = Field(foreign_key="policies.id", nullable=False) + custom_name: str | None = Field(default=None) + created_at: datetime = Field(default_factory=datetime.now) + updated_at: datetime = Field(default_factory=datetime.now) + + +# Create a dummy model class for the table link +class UserPolicy(BaseModel): + pass + + +user_policy_table_link = TableLink( + model_cls=UserPolicy, + table_cls=UserPolicyTable, +) \ No newline at end of file diff --git a/src/policyengine/database/user_report_table.py b/src/policyengine/database/user_report_table.py new file mode 100644 index 00000000..e09f3389 --- /dev/null +++ b/src/policyengine/database/user_report_table.py @@ -0,0 +1,33 @@ +from datetime import datetime +from typing import TYPE_CHECKING +from uuid import uuid4 + +from sqlmodel import Field, SQLModel +from pydantic import BaseModel + +from .link import TableLink + +if TYPE_CHECKING: + from .database import Database + + +class UserReportTable(SQLModel, table=True): + __tablename__ = "user_reports" + + id: str 
= Field(default_factory=lambda: str(uuid4()), primary_key=True) + user_id: str = Field(foreign_key="users.id", nullable=False) + report_id: str = Field(foreign_key="reports.id", nullable=False) + custom_name: str | None = Field(default=None) + created_at: datetime = Field(default_factory=datetime.now) + updated_at: datetime = Field(default_factory=datetime.now) + + +# Create a dummy model class for the table link +class UserReport(BaseModel): + pass + + +user_report_table_link = TableLink( + model_cls=UserReport, + table_cls=UserReportTable, +) \ No newline at end of file diff --git a/src/policyengine/database/user_simulation_table.py b/src/policyengine/database/user_simulation_table.py new file mode 100644 index 00000000..75c39fba --- /dev/null +++ b/src/policyengine/database/user_simulation_table.py @@ -0,0 +1,33 @@ +from datetime import datetime +from typing import TYPE_CHECKING +from uuid import uuid4 + +from sqlmodel import Field, SQLModel +from pydantic import BaseModel + +from .link import TableLink + +if TYPE_CHECKING: + from .database import Database + + +class UserSimulationTable(SQLModel, table=True): + __tablename__ = "user_simulations" + + id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True) + user_id: str = Field(foreign_key="users.id", nullable=False) + simulation_id: str = Field(foreign_key="simulations.id", nullable=False) + custom_name: str | None = Field(default=None) + created_at: datetime = Field(default_factory=datetime.now) + updated_at: datetime = Field(default_factory=datetime.now) + + +# Create a dummy model class for the table link +class UserSimulation(BaseModel): + pass + + +user_simulation_table_link = TableLink( + model_cls=UserSimulation, + table_cls=UserSimulationTable, +) \ No newline at end of file From 523d2af6d8ee4804f764f288e0257c986481405a Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Fri, 3 Oct 2025 17:57:58 +0100 Subject: [PATCH 07/35] Working sim impacts! 
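
This patch adds an error column to SimulationTable and guards in
Simulation.run() that raise ValueError when model or dataset is unset.
A sketch of how a caller might use the two together follows;
run_and_store is a hypothetical helper, not part of this series, and it
assumes the Database.set API used elsewhere in the document:

    # Hypothetical helper (not in this series): keep a failed run's
    # exception on the new Simulation.error field, then persist it.
    from policyengine.database import Database
    from policyengine.models import Simulation

    def run_and_store(database: Database, simulation: Simulation) -> Simulation:
        try:
            simulation.run()  # raises ValueError if model/dataset is unset
        except ValueError as exc:
            # Mirrors the dynamic-attribute pattern SimulationTable uses below
            simulation.error = str(exc)
        database.set(simulation)  # stored in the new SimulationTable.error column
        return simulation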
--- docs/quickstart.ipynb | 62 +++++++++---------- src/policyengine/database/database.py | 26 ++++++++ .../database/parameter_value_table.py | 3 +- src/policyengine/database/simulation_table.py | 8 ++- src/policyengine/database/user_table.py | 3 + src/policyengine/models/parameter_value.py | 2 +- src/policyengine/models/policyengine_uk.py | 4 ++ src/policyengine/models/policyengine_us.py | 4 ++ src/policyengine/models/simulation.py | 11 +++- src/policyengine/models/user.py | 1 + 10 files changed, 86 insertions(+), 38 deletions(-) diff --git a/docs/quickstart.ipynb b/docs/quickstart.ipynb index f34568f8..8c497bff 100644 --- a/docs/quickstart.ipynb +++ b/docs/quickstart.ipynb @@ -66,16 +66,16 @@ "£500,000" ], "y": [ - 6598054.774273342, - 10307676.305816924, - 7178099.41165927, - 4267341.442682308, - 1717525.567486669, - 1319705.9065854247, - 326077.61111739336, - 187608.23132836912, - 63106.63353048405, - 41838.373842805624 + 6530423.253505196, + 10205681.438694796, + 6918333.897778195, + 4101047.3396776896, + 1656640.5745191968, + 1312315.5343185724, + 706991.8843209555, + 277644.11414299323, + 72024.26234725268, + 34894.54357677698 ] } ], @@ -1110,16 +1110,16 @@ "£500,000" ], "y": [ - 6598054.774273342, - 10307676.305816924, - 7178099.41165927, - 4267341.442682308, - 1717525.567486669, - 1319705.9065854247, - 326077.61111739336, - 187608.23132836912, - 63106.63353048405, - 41838.373842805624 + 6530423.253505196, + 10205681.438694796, + 6918333.897778195, + 4101047.3396776896, + 1656640.5745191968, + 1312315.5343185724, + 706991.8843209555, + 277644.11414299323, + 72024.26234725268, + 34894.54357677698 ] }, { @@ -1148,16 +1148,16 @@ "£500,000" ], "y": [ - 6142729.718842472, - 10318841.418840772, - 6940624.716328709, - 4465347.839217288, - 2002931.6650648424, - 1472352.5016867262, - 341808.24952948757, - 218180.35939976107, - 63106.63353048405, - 41838.373842805624 + 6131768.670854956, + 10113037.29630455, + 6805540.101212463, + 4269018.780282864, + 1910859.1447649547, + 1503954.6294220393, + 715336.0493238519, + 283341.39915852225, + 72024.26234725268, + 34894.54357677698 ] } ], @@ -2224,7 +2224,7 @@ { "data": { "text/plain": [ - "Policy(id='118195fb-4485-4616-bc9d-ed5ea540c300', name='Increase personal allowance to £20,000', description='A policy to increase the personal allowance for income tax to £20,000.', parameter_values=[ParameterValue(id='a8cfa8e2-f339-4896-a90f-ef470e365526', parameter=Parameter(id='gov.hmrc.income_tax.allowances.personal_allowance.amount', description=None, data_type=None, model=Model(id='policyengine_uk', name='PolicyEngine UK', description=\"PolicyEngine's open-source tax-benefit microsimulation model.\", simulation_function=), label=None, unit=None), value=20000, start_date=datetime.datetime(2029, 1, 1, 0, 0), end_date=None)], simulation_modifier=None, created_at=datetime.datetime(2025, 9, 29, 18, 21, 46, 643799), updated_at=datetime.datetime(2025, 9, 29, 18, 21, 46, 643805))" + "Policy(id='77545b6d-1294-4f7d-8646-cbc6d9d4b054', name='Increase personal allowance to £20,000', description='A policy to increase the personal allowance for income tax to £20,000.', parameter_values=[ParameterValue(id='0416cef8-f6f1-4925-bba6-60c2527838ae', parameter=Parameter(id='gov.hmrc.income_tax.allowances.personal_allowance.amount', description=None, data_type=None, model=Model(id='policyengine_uk', name='PolicyEngine UK', description=\"PolicyEngine's open-source tax-benefit microsimulation model.\", simulation_function=), label=None, unit=None), value=20000, 
start_date=datetime.datetime(2029, 1, 1, 0, 0), end_date=None)], simulation_modifier=None, created_at=datetime.datetime(2025, 10, 3, 17, 46, 47, 141804), updated_at=datetime.datetime(2025, 10, 3, 17, 46, 47, 141809))" ] }, "execution_count": 5, diff --git a/src/policyengine/database/database.py b/src/policyengine/database/database.py index dc5bb539..38d4e0de 100644 --- a/src/policyengine/database/database.py +++ b/src/policyengine/database/database.py @@ -79,6 +79,32 @@ def reset(self): self.drop_tables() self.create_tables() + def ensure_anonymous_user(self): + """Ensure the anonymous user exists in the database for development.""" + from datetime import datetime + from policyengine.models.user import User + from sqlmodel import select + from .user_table import UserTable + + # Check if anonymous user exists + stmt = select(UserTable).where(UserTable.id == "anonymous") + existing = self.session.exec(stmt).first() + + if not existing: + # Create anonymous user with UK model as default + anonymous_user = UserTable( + id="anonymous", + username="anonymous", + first_name="Anonymous", + last_name="User", + email=None, + current_model_id="policyengine_uk", + created_at=datetime.utcnow(), + updated_at=datetime.utcnow(), + ) + self.session.add(anonymous_user) + self.session.commit() + def __enter__(self): """Context manager entry - creates a session.""" self.session = Session(self.engine) diff --git a/src/policyengine/database/parameter_value_table.py b/src/policyengine/database/parameter_value_table.py index 7bd02d0a..6bfd60dd 100644 --- a/src/policyengine/database/parameter_value_table.py +++ b/src/policyengine/database/parameter_value_table.py @@ -81,8 +81,7 @@ def convert_to_model(self, database: "Database" = None) -> ParameterValue: ParameterTable.model_id == self.model_id ) ).first() - if param_table: - parameter = param_table.convert_to_model(database) + parameter = param_table.convert_to_model(database) # Handle special string values value = self.value diff --git a/src/policyengine/database/simulation_table.py b/src/policyengine/database/simulation_table.py index de45a419..de6eae58 100644 --- a/src/policyengine/database/simulation_table.py +++ b/src/policyengine/database/simulation_table.py @@ -33,6 +33,7 @@ class SimulationTable(SQLModel, table=True): ) result: bytes | None = Field(default=None) + error: str | None = Field(default=None) @classmethod def convert_from_model(cls, model: Simulation, database: "Database" = None) -> "SimulationTable": @@ -68,6 +69,7 @@ def convert_from_model(cls, model: Simulation, database: "Database" = None) -> " model_id=model.model.id if model.model else None, model_version_id=model.model_version.id if model.model_version else None, result=compress_data(model.result) if model.result else None, + error=getattr(model, 'error', None), ) # Handle nested aggregates if database is provided @@ -205,7 +207,7 @@ def convert_to_model(self, database: "Database" = None) -> Simulation: agg_model = agg_table.convert_to_model(None) aggregates.append(agg_model) - return Simulation( + sim = Simulation( id=self.id, created_at=self.created_at, updated_at=self.updated_at, @@ -217,6 +219,10 @@ def convert_to_model(self, database: "Database" = None) -> Simulation: result=decompress_data(self.result) if self.result else None, aggregates=aggregates, ) + # Add error as dynamic attribute if present + if self.error: + sim.error = self.error + return sim simulation_table_link = TableLink( diff --git a/src/policyengine/database/user_table.py b/src/policyengine/database/user_table.py 
index d663ac8f..e6966f17 100644 --- a/src/policyengine/database/user_table.py +++ b/src/policyengine/database/user_table.py @@ -22,6 +22,7 @@ class UserTable(SQLModel, table=True, extend_existing=True): first_name: str | None = Field(default=None) last_name: str | None = Field(default=None) email: str | None = Field(default=None) + current_model_id: str = Field(default="policyengine_uk") created_at: datetime = Field(default_factory=datetime.utcnow) updated_at: datetime = Field(default_factory=datetime.utcnow) @@ -34,6 +35,7 @@ def convert_from_model(cls, model: User, database: "Database" = None) -> "UserTa first_name=model.first_name, last_name=model.last_name, email=model.email, + current_model_id=model.current_model_id, created_at=model.created_at, updated_at=model.updated_at, ) @@ -46,6 +48,7 @@ def convert_to_model(self, database: "Database" = None) -> User: first_name=self.first_name, last_name=self.last_name, email=self.email, + current_model_id=self.current_model_id, created_at=self.created_at, updated_at=self.updated_at, ) diff --git a/src/policyengine/models/parameter_value.py b/src/policyengine/models/parameter_value.py index a7867557..c997d794 100644 --- a/src/policyengine/models/parameter_value.py +++ b/src/policyengine/models/parameter_value.py @@ -8,7 +8,7 @@ class ParameterValue(BaseModel): id: str = Field(default_factory=lambda: str(uuid4())) - parameter: Parameter + parameter: Parameter | None = None value: float | int | str | bool | list | None = None start_date: datetime end_date: datetime | None = None diff --git a/src/policyengine/models/policyengine_uk.py b/src/policyengine/models/policyengine_uk.py index 5b97ccfb..22c72546 100644 --- a/src/policyengine/models/policyengine_uk.py +++ b/src/policyengine/models/policyengine_uk.py @@ -28,6 +28,8 @@ def run_policyengine_uk( def simulation_modifier(sim: Microsimulation): if policy is not None and len(policy.parameter_values) > 0: for parameter_value in policy.parameter_values: + if parameter_value.parameter is None: + raise ValueError(f"Parameter value {parameter_value.id} has no parameter set - the policy contains invalid data") sim.tax_benefit_system.parameters.get_child( parameter_value.parameter.id ).update( @@ -40,6 +42,8 @@ def simulation_modifier(sim: Microsimulation): if dynamic is not None and len(dynamic.parameter_values) > 0: for parameter_value in dynamic.parameter_values: + if parameter_value.parameter is None: + raise ValueError(f"Parameter value {parameter_value.id} has no parameter set - the dynamic contains invalid data") sim.tax_benefit_system.parameters.get_child( parameter_value.parameter.id ).update( diff --git a/src/policyengine/models/policyengine_us.py b/src/policyengine/models/policyengine_us.py index 9e2eeb7d..807859f0 100644 --- a/src/policyengine/models/policyengine_us.py +++ b/src/policyengine/models/policyengine_us.py @@ -35,6 +35,8 @@ def run_policyengine_us( def simulation_modifier(sim: Microsimulation): if policy is not None and len(policy.parameter_values) > 0: for parameter_value in policy.parameter_values: + if parameter_value.parameter is None: + raise ValueError(f"Parameter value {parameter_value.id} has no parameter set - the policy contains invalid data") sim.tax_benefit_system.parameters.get_child( parameter_value.parameter.id ).update( @@ -47,6 +49,8 @@ def simulation_modifier(sim: Microsimulation): if dynamic is not None and len(dynamic.parameter_values) > 0: for parameter_value in dynamic.parameter_values: + if parameter_value.parameter is None: + raise ValueError(f"Parameter 
value {parameter_value.id} has no parameter set - the dynamic contains invalid data") sim.tax_benefit_system.parameters.get_child( parameter_value.parameter.id ).update( diff --git a/src/policyengine/models/simulation.py b/src/policyengine/models/simulation.py index a6ed7a5a..7c0abad1 100644 --- a/src/policyengine/models/simulation.py +++ b/src/policyengine/models/simulation.py @@ -18,14 +18,19 @@ class Simulation(BaseModel): policy: Policy | None = None dynamic: Dynamic | None = None - dataset: Dataset + dataset: Dataset | None = None - model: Model - model_version: ModelVersion + model: Model | None = None + model_version: ModelVersion | None = None result: Any | None = None aggregates: list = Field(default_factory=list) # Will be list[Aggregate] but avoid circular import def run(self): + if not self.model: + raise ValueError("Cannot run simulation: model is not set") + if not self.dataset: + raise ValueError("Cannot run simulation: dataset is not set") + self.result = self.model.simulation_function( dataset=self.dataset, policy=self.policy, diff --git a/src/policyengine/models/user.py b/src/policyengine/models/user.py index dee924e1..af29adff 100644 --- a/src/policyengine/models/user.py +++ b/src/policyengine/models/user.py @@ -10,5 +10,6 @@ class User(BaseModel): first_name: str | None = None last_name: str | None = None email: str | None = None + current_model_id: str = "policyengine_uk" # Default to UK model created_at: datetime | None = None updated_at: datetime | None = None From e6cb739e052aed5a281ae6a353fccf8fe313e635 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Mon, 6 Oct 2025 14:31:23 +0100 Subject: [PATCH 08/35] Move nonneeded code --- src/policyengine/database/__init__.py | 24 ---- src/policyengine/database/database.py | 42 +----- .../database/report_element_table.py | 106 ---------------- src/policyengine/database/report_table.py | 120 ------------------ .../database/user_dataset_table.py | 33 ----- .../database/user_dynamic_table.py | 33 ----- .../database/user_policy_table.py | 33 ----- .../database/user_report_table.py | 33 ----- .../database/user_simulation_table.py | 33 ----- src/policyengine/database/user_table.py | 60 --------- src/policyengine/models/__init__.py | 3 - src/policyengine/models/report.py | 20 --- src/policyengine/models/report_element.py | 38 ------ src/policyengine/models/user.py | 15 --- 14 files changed, 2 insertions(+), 591 deletions(-) delete mode 100644 src/policyengine/database/report_element_table.py delete mode 100644 src/policyengine/database/report_table.py delete mode 100644 src/policyengine/database/user_dataset_table.py delete mode 100644 src/policyengine/database/user_dynamic_table.py delete mode 100644 src/policyengine/database/user_policy_table.py delete mode 100644 src/policyengine/database/user_report_table.py delete mode 100644 src/policyengine/database/user_simulation_table.py delete mode 100644 src/policyengine/database/user_table.py delete mode 100644 src/policyengine/models/report.py delete mode 100644 src/policyengine/models/report_element.py delete mode 100644 src/policyengine/models/user.py diff --git a/src/policyengine/database/__init__.py b/src/policyengine/database/__init__.py index 293a41b1..69f34d89 100644 --- a/src/policyengine/database/__init__.py +++ b/src/policyengine/database/__init__.py @@ -25,16 +25,8 @@ VersionedDatasetTable, versioned_dataset_table_link, ) -from .report_table import ReportTable, report_table_link -from .report_element_table import ReportElementTable, report_element_table_link from 
.aggregate import AggregateTable, aggregate_table_link from .aggregate_change import AggregateChangeTable, aggregate_change_table_link -from .user_table import UserTable, user_table_link -from .user_policy_table import UserPolicyTable, user_policy_table_link -from .user_dynamic_table import UserDynamicTable, user_dynamic_table_link -from .user_dataset_table import UserDatasetTable, user_dataset_table_link -from .user_simulation_table import UserSimulationTable, user_simulation_table_link -from .user_report_table import UserReportTable, user_report_table_link __all__ = [ "Database", @@ -51,16 +43,8 @@ "BaselineParameterValueTable", "BaselineVariableTable", "SimulationTable", - "ReportTable", - "ReportElementTable", "AggregateTable", "AggregateChangeTable", - "UserTable", - "UserPolicyTable", - "UserDynamicTable", - "UserDatasetTable", - "UserSimulationTable", - "UserReportTable", # Links "model_table_link", "model_version_table_link", @@ -73,14 +57,6 @@ "baseline_parameter_value_table_link", "baseline_variable_table_link", "simulation_table_link", - "report_table_link", - "report_element_table_link", "aggregate_table_link", "aggregate_change_table_link", - "user_table_link", - "user_policy_table_link", - "user_dynamic_table_link", - "user_dataset_table_link", - "user_simulation_table_link", - "user_report_table_link", ] diff --git a/src/policyengine/database/database.py b/src/policyengine/database/database.py index 38d4e0de..c1c20755 100644 --- a/src/policyengine/database/database.py +++ b/src/policyengine/database/database.py @@ -15,15 +15,7 @@ from .parameter_table import parameter_table_link from .parameter_value_table import parameter_value_table_link from .policy_table import policy_table_link -from .report_element_table import report_element_table_link -from .report_table import report_table_link from .simulation_table import simulation_table_link -from .user_table import user_table_link -from .user_policy_table import user_policy_table_link -from .user_dynamic_table import user_dynamic_table_link -from .user_dataset_table import user_dataset_table_link -from .user_simulation_table import user_simulation_table_link -from .user_report_table import user_report_table_link from .versioned_dataset_table import versioned_dataset_table_link @@ -50,14 +42,6 @@ def __init__(self, url: str): baseline_variable_table_link, simulation_table_link, aggregate_table_link, - user_table_link, - user_policy_table_link, - user_dynamic_table_link, - user_dataset_table_link, - user_simulation_table_link, - user_report_table_link, - report_table_link, - report_element_table_link, ]: self.register_table(link) @@ -80,30 +64,8 @@ def reset(self): self.create_tables() def ensure_anonymous_user(self): - """Ensure the anonymous user exists in the database for development.""" - from datetime import datetime - from policyengine.models.user import User - from sqlmodel import select - from .user_table import UserTable - - # Check if anonymous user exists - stmt = select(UserTable).where(UserTable.id == "anonymous") - existing = self.session.exec(stmt).first() - - if not existing: - # Create anonymous user with UK model as default - anonymous_user = UserTable( - id="anonymous", - username="anonymous", - first_name="Anonymous", - last_name="User", - email=None, - current_model_id="policyengine_uk", - created_at=datetime.utcnow(), - updated_at=datetime.utcnow(), - ) - self.session.add(anonymous_user) - self.session.commit() + """Deprecated: This method no longer exists as user management has been moved to the API 
layer.""" + pass def __enter__(self): """Context manager entry - creates a session.""" diff --git a/src/policyengine/database/report_element_table.py b/src/policyengine/database/report_element_table.py deleted file mode 100644 index cc69e83e..00000000 --- a/src/policyengine/database/report_element_table.py +++ /dev/null @@ -1,106 +0,0 @@ -import uuid -from datetime import datetime - -from sqlmodel import Field, SQLModel, Column, JSON -from typing import TYPE_CHECKING - -from policyengine.models.report_element import ReportElement - -from .link import TableLink - -if TYPE_CHECKING: - from .database import Database - - -class ReportElementTable(SQLModel, table=True, extend_existing=True): - __tablename__ = "report_elements" - - id: str = Field( - primary_key=True, default_factory=lambda: str(uuid.uuid4()) - ) - label: str = Field(nullable=False) - type: str = Field(nullable=False) # "chart" or "markdown" - - # Data source - data_table: str | None = Field(default=None) # "aggregates" or "aggregate_changes" - - # Chart configuration - chart_type: str | None = Field( - default=None - ) # "bar", "line", "scatter", "area", "pie" - x_axis_variable: str | None = Field(default=None) - y_axis_variable: str | None = Field(default=None) - group_by: str | None = Field(default=None) - color_by: str | None = Field(default=None) - size_by: str | None = Field(default=None) - - # Markdown specific - markdown_content: str | None = Field(default=None) - - # Metadata - report_id: str | None = Field(default=None, foreign_key="reports.id") - user_id: str | None = Field(default=None, foreign_key="users.id") - model_version_id: str | None = Field(default=None, foreign_key="model_versions.id") - position: int | None = Field(default=None) - visible: bool | None = Field(default=True) - custom_config: dict | None = Field(default=None, sa_column=Column(JSON)) - report_element_metadata: dict | None = Field(default=None, sa_column=Column(JSON)) - created_at: datetime = Field(default_factory=datetime.utcnow) - updated_at: datetime = Field(default_factory=datetime.utcnow) - - @classmethod - def convert_from_model(cls, model: ReportElement, database: "Database" = None) -> "ReportElementTable": - """Convert a ReportElement instance to a ReportElementTable instance.""" - return cls( - id=model.id, - label=model.label, - type=model.type, - data_table=model.data_table, - chart_type=model.chart_type, - x_axis_variable=model.x_axis_variable, - y_axis_variable=model.y_axis_variable, - group_by=model.group_by, - color_by=model.color_by, - size_by=model.size_by, - markdown_content=model.markdown_content, - report_id=model.report_id, - user_id=model.user_id, - model_version_id=model.model_version_id, - position=model.position, - visible=model.visible, - custom_config=model.custom_config, - report_element_metadata=model.report_element_metadata, - created_at=model.created_at, - updated_at=model.updated_at, - ) - - def convert_to_model(self, database: "Database" = None) -> ReportElement: - """Convert this ReportElementTable instance to a ReportElement instance.""" - return ReportElement( - id=self.id, - label=self.label, - type=self.type, - data_table=self.data_table, - chart_type=self.chart_type, - x_axis_variable=self.x_axis_variable, - y_axis_variable=self.y_axis_variable, - group_by=self.group_by, - color_by=self.color_by, - size_by=self.size_by, - markdown_content=self.markdown_content, - report_id=self.report_id, - user_id=self.user_id, - model_version_id=self.model_version_id, - position=self.position, - visible=self.visible, - 
custom_config=self.custom_config, - report_element_metadata=self.report_element_metadata, - created_at=self.created_at, - updated_at=self.updated_at, - ) - - -report_element_table_link = TableLink( - model_cls=ReportElement, - table_cls=ReportElementTable, -) diff --git a/src/policyengine/database/report_table.py b/src/policyengine/database/report_table.py deleted file mode 100644 index 79c11cf0..00000000 --- a/src/policyengine/database/report_table.py +++ /dev/null @@ -1,120 +0,0 @@ -import uuid -from datetime import datetime - -from sqlmodel import Field, SQLModel -from typing import TYPE_CHECKING - -from policyengine.models.report import Report - -from .link import TableLink - -if TYPE_CHECKING: - from .database import Database - - -class ReportTable(SQLModel, table=True, extend_existing=True): - __tablename__ = "reports" - - id: str = Field( - primary_key=True, default_factory=lambda: str(uuid.uuid4()) - ) - label: str = Field(nullable=False) - created_at: datetime = Field(default_factory=datetime.utcnow) - - @classmethod - def convert_from_model(cls, model: Report, database: "Database" = None) -> "ReportTable": - """Convert a Report instance to a ReportTable instance.""" - report_table = cls( - id=model.id, - label=model.label, - created_at=model.created_at, - ) - - # Handle nested report elements if database is provided - if database and model.elements: - from .report_element_table import ReportElementTable - from sqlmodel import select - - # First ensure the report table is saved to the database - # This is necessary so the foreign key constraint is satisfied - # Check if it already exists - existing_report = database.session.exec( - select(ReportTable).where(ReportTable.id == model.id) - ).first() - - if not existing_report: - database.session.add(report_table) - database.session.flush() - - # Track which element IDs we want to keep - desired_elem_ids = {elem.id for elem in model.elements} - - # Delete only elements linked to this report that are NOT in the new list - existing_elems = database.session.exec( - select(ReportElementTable).where(ReportElementTable.report_id == model.id) - ).all() - for elem in existing_elems: - if elem.id not in desired_elem_ids: - database.session.delete(elem) - - # Now save/update the elements - for i, element in enumerate(model.elements): - # Check if this element already exists in the database - existing_elem = database.session.exec( - select(ReportElementTable).where(ReportElementTable.id == element.id) - ).first() - - if existing_elem: - # Update existing element - elem_table = ReportElementTable.convert_from_model(element, database) - existing_elem.report_id = model.id - existing_elem.position = i - existing_elem.label = elem_table.label - existing_elem.type = elem_table.type - existing_elem.markdown_content = elem_table.markdown_content - existing_elem.chart_type = elem_table.chart_type - existing_elem.x_axis_variable = elem_table.x_axis_variable - existing_elem.y_axis_variable = elem_table.y_axis_variable - existing_elem.baseline_simulation_id = elem_table.baseline_simulation_id - existing_elem.reform_simulation_id = elem_table.reform_simulation_id - else: - # Create new element - elem_table = ReportElementTable.convert_from_model(element, database) - elem_table.report_id = model.id # Link to this report - elem_table.position = i # Maintain order - database.session.add(elem_table) - database.session.flush() - - return report_table - - def convert_to_model(self, database: "Database" = None) -> Report: - """Convert this ReportTable instance to a 
Report instance.""" - # Load nested report elements if database is provided - elements = [] - if database: - from .report_element_table import ReportElementTable - from sqlmodel import select - - # Query for all elements linked to this report, ordered by position - elem_tables = database.session.exec( - select(ReportElementTable) - .where(ReportElementTable.report_id == self.id) - .order_by(ReportElementTable.position) - ).all() - - # Convert each one to a model - for elem_table in elem_tables: - elements.append(elem_table.convert_to_model(database)) - - return Report( - id=self.id, - label=self.label, - created_at=self.created_at, - elements=elements, - ) - - -report_table_link = TableLink( - model_cls=Report, - table_cls=ReportTable, -) diff --git a/src/policyengine/database/user_dataset_table.py b/src/policyengine/database/user_dataset_table.py deleted file mode 100644 index b2ff27b8..00000000 --- a/src/policyengine/database/user_dataset_table.py +++ /dev/null @@ -1,33 +0,0 @@ -from datetime import datetime -from typing import TYPE_CHECKING -from uuid import uuid4 - -from sqlmodel import Field, SQLModel -from pydantic import BaseModel - -from .link import TableLink - -if TYPE_CHECKING: - from .database import Database - - -class UserDatasetTable(SQLModel, table=True): - __tablename__ = "user_datasets" - - id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True) - user_id: str = Field(foreign_key="users.id", nullable=False) - dataset_id: str = Field(foreign_key="datasets.id", nullable=False) - custom_name: str | None = Field(default=None) - created_at: datetime = Field(default_factory=datetime.now) - updated_at: datetime = Field(default_factory=datetime.now) - - -# Create a dummy model class for the table link -class UserDataset(BaseModel): - pass - - -user_dataset_table_link = TableLink( - model_cls=UserDataset, - table_cls=UserDatasetTable, -) \ No newline at end of file diff --git a/src/policyengine/database/user_dynamic_table.py b/src/policyengine/database/user_dynamic_table.py deleted file mode 100644 index e307c6c8..00000000 --- a/src/policyengine/database/user_dynamic_table.py +++ /dev/null @@ -1,33 +0,0 @@ -from datetime import datetime -from typing import TYPE_CHECKING -from uuid import uuid4 - -from sqlmodel import Field, SQLModel -from pydantic import BaseModel - -from .link import TableLink - -if TYPE_CHECKING: - from .database import Database - - -class UserDynamicTable(SQLModel, table=True): - __tablename__ = "user_dynamics" - - id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True) - user_id: str = Field(foreign_key="users.id", nullable=False) - dynamic_id: str = Field(foreign_key="dynamics.id", nullable=False) - custom_name: str | None = Field(default=None) - created_at: datetime = Field(default_factory=datetime.now) - updated_at: datetime = Field(default_factory=datetime.now) - - -# Create a dummy model class for the table link -class UserDynamic(BaseModel): - pass - - -user_dynamic_table_link = TableLink( - model_cls=UserDynamic, - table_cls=UserDynamicTable, -) \ No newline at end of file diff --git a/src/policyengine/database/user_policy_table.py b/src/policyengine/database/user_policy_table.py deleted file mode 100644 index 5fa14ed6..00000000 --- a/src/policyengine/database/user_policy_table.py +++ /dev/null @@ -1,33 +0,0 @@ -from datetime import datetime -from typing import TYPE_CHECKING -from uuid import uuid4 - -from sqlmodel import Field, SQLModel -from pydantic import BaseModel - -from .link import TableLink - -if TYPE_CHECKING: - 
from .database import Database - - -class UserPolicyTable(SQLModel, table=True): - __tablename__ = "user_policies" - - id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True) - user_id: str = Field(foreign_key="users.id", nullable=False) - policy_id: str = Field(foreign_key="policies.id", nullable=False) - custom_name: str | None = Field(default=None) - created_at: datetime = Field(default_factory=datetime.now) - updated_at: datetime = Field(default_factory=datetime.now) - - -# Create a dummy model class for the table link -class UserPolicy(BaseModel): - pass - - -user_policy_table_link = TableLink( - model_cls=UserPolicy, - table_cls=UserPolicyTable, -) \ No newline at end of file diff --git a/src/policyengine/database/user_report_table.py b/src/policyengine/database/user_report_table.py deleted file mode 100644 index e09f3389..00000000 --- a/src/policyengine/database/user_report_table.py +++ /dev/null @@ -1,33 +0,0 @@ -from datetime import datetime -from typing import TYPE_CHECKING -from uuid import uuid4 - -from sqlmodel import Field, SQLModel -from pydantic import BaseModel - -from .link import TableLink - -if TYPE_CHECKING: - from .database import Database - - -class UserReportTable(SQLModel, table=True): - __tablename__ = "user_reports" - - id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True) - user_id: str = Field(foreign_key="users.id", nullable=False) - report_id: str = Field(foreign_key="reports.id", nullable=False) - custom_name: str | None = Field(default=None) - created_at: datetime = Field(default_factory=datetime.now) - updated_at: datetime = Field(default_factory=datetime.now) - - -# Create a dummy model class for the table link -class UserReport(BaseModel): - pass - - -user_report_table_link = TableLink( - model_cls=UserReport, - table_cls=UserReportTable, -) \ No newline at end of file diff --git a/src/policyengine/database/user_simulation_table.py b/src/policyengine/database/user_simulation_table.py deleted file mode 100644 index 75c39fba..00000000 --- a/src/policyengine/database/user_simulation_table.py +++ /dev/null @@ -1,33 +0,0 @@ -from datetime import datetime -from typing import TYPE_CHECKING -from uuid import uuid4 - -from sqlmodel import Field, SQLModel -from pydantic import BaseModel - -from .link import TableLink - -if TYPE_CHECKING: - from .database import Database - - -class UserSimulationTable(SQLModel, table=True): - __tablename__ = "user_simulations" - - id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True) - user_id: str = Field(foreign_key="users.id", nullable=False) - simulation_id: str = Field(foreign_key="simulations.id", nullable=False) - custom_name: str | None = Field(default=None) - created_at: datetime = Field(default_factory=datetime.now) - updated_at: datetime = Field(default_factory=datetime.now) - - -# Create a dummy model class for the table link -class UserSimulation(BaseModel): - pass - - -user_simulation_table_link = TableLink( - model_cls=UserSimulation, - table_cls=UserSimulationTable, -) \ No newline at end of file diff --git a/src/policyengine/database/user_table.py b/src/policyengine/database/user_table.py deleted file mode 100644 index e6966f17..00000000 --- a/src/policyengine/database/user_table.py +++ /dev/null @@ -1,60 +0,0 @@ -import uuid -from datetime import datetime - -from sqlmodel import Field, SQLModel -from typing import TYPE_CHECKING - -from policyengine.models.user import User - -from .link import TableLink - -if TYPE_CHECKING: - from .database import Database - - 
-class UserTable(SQLModel, table=True, extend_existing=True): - __tablename__ = "users" - - id: str = Field( - primary_key=True, default_factory=lambda: str(uuid.uuid4()) - ) - username: str = Field(nullable=False, unique=True) - first_name: str | None = Field(default=None) - last_name: str | None = Field(default=None) - email: str | None = Field(default=None) - current_model_id: str = Field(default="policyengine_uk") - created_at: datetime = Field(default_factory=datetime.utcnow) - updated_at: datetime = Field(default_factory=datetime.utcnow) - - @classmethod - def convert_from_model(cls, model: User, database: "Database" = None) -> "UserTable": - """Convert a User instance to a UserTable instance.""" - return cls( - id=model.id, - username=model.username, - first_name=model.first_name, - last_name=model.last_name, - email=model.email, - current_model_id=model.current_model_id, - created_at=model.created_at, - updated_at=model.updated_at, - ) - - def convert_to_model(self, database: "Database" = None) -> User: - """Convert this UserTable instance to a User instance.""" - return User( - id=self.id, - username=self.username, - first_name=self.first_name, - last_name=self.last_name, - email=self.email, - current_model_id=self.current_model_id, - created_at=self.created_at, - updated_at=self.updated_at, - ) - - -user_table_link = TableLink( - model_cls=User, - table_cls=UserTable, -) diff --git a/src/policyengine/models/__init__.py b/src/policyengine/models/__init__.py index de5fd8c9..24fb823c 100644 --- a/src/policyengine/models/__init__.py +++ b/src/policyengine/models/__init__.py @@ -24,10 +24,7 @@ from .policyengine_us import ( policyengine_us_model as policyengine_us_model, ) -from .report import Report as Report -from .report_element import ReportElement as ReportElement from .simulation import Simulation as Simulation -from .user import User as User from .versioned_dataset import VersionedDataset as VersionedDataset # Rebuild models to handle circular references diff --git a/src/policyengine/models/report.py b/src/policyengine/models/report.py deleted file mode 100644 index 2ae0cd3b..00000000 --- a/src/policyengine/models/report.py +++ /dev/null @@ -1,20 +0,0 @@ -import uuid -from datetime import datetime -from typing import TYPE_CHECKING, ForwardRef - -from pydantic import BaseModel, Field - -if TYPE_CHECKING: - from .report_element import ReportElement - - -class Report(BaseModel): - id: str = Field(default_factory=lambda: str(uuid.uuid4())) - label: str - created_at: datetime | None = None - elements: list[ForwardRef("ReportElement")] = Field(default_factory=list) - - -# Import after class definition to avoid circular import -from .report_element import ReportElement -Report.model_rebuild() diff --git a/src/policyengine/models/report_element.py b/src/policyengine/models/report_element.py deleted file mode 100644 index 180ec26c..00000000 --- a/src/policyengine/models/report_element.py +++ /dev/null @@ -1,38 +0,0 @@ -import uuid -from datetime import datetime -from typing import Literal - -from pydantic import BaseModel, Field - - -class ReportElement(BaseModel): - id: str = Field(default_factory=lambda: str(uuid.uuid4())) - label: str - type: Literal["chart", "markdown"] - - # Data source - data_table: Literal["aggregates", "aggregate_changes"] | None = None # Which table to pull from - - # Chart configuration - chart_type: ( - Literal["bar", "line", "scatter", "area", "pie", "histogram"] | None - ) = None - x_axis_variable: str | None = None # Column name from the table - 
y_axis_variable: str | None = None # Column name from the table - group_by: str | None = None # Column to group/split series by - color_by: str | None = None # Column for color mapping - size_by: str | None = None # Column for size mapping (bubble charts) - - # Markdown specific - markdown_content: str | None = None - - # Metadata - report_id: str | None = None - user_id: str | None = None - model_version_id: str | None = None - position: int | None = None - visible: bool | None = True - custom_config: dict | None = None # Additional chart-specific config - report_element_metadata: dict | None = None # General metadata field for flexible data storage - created_at: datetime | None = None - updated_at: datetime | None = None diff --git a/src/policyengine/models/user.py b/src/policyengine/models/user.py deleted file mode 100644 index af29adff..00000000 --- a/src/policyengine/models/user.py +++ /dev/null @@ -1,15 +0,0 @@ -import uuid -from datetime import datetime - -from pydantic import BaseModel, Field - - -class User(BaseModel): - id: str = Field(default_factory=lambda: str(uuid.uuid4())) - username: str - first_name: str | None = None - last_name: str | None = None - email: str | None = None - current_model_id: str = "policyengine_uk" # Default to UK model - created_at: datetime | None = None - updated_at: datetime | None = None From 56f6d3e38e10efed6537553aefc523fae36c3361 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Mon, 6 Oct 2025 15:39:05 +0100 Subject: [PATCH 09/35] Add tests --- db_init.py | 15 ++ src/policyengine/database/database.py | 50 ++++- tests/test_database_init.py | 259 ++++++++++++++++++++++++++ tests/test_database_postgres.py | 169 +++++++++++++++++ 4 files changed, 488 insertions(+), 5 deletions(-) create mode 100644 db_init.py create mode 100644 tests/test_database_init.py create mode 100644 tests/test_database_postgres.py diff --git a/db_init.py b/db_init.py new file mode 100644 index 00000000..f6a3e672 --- /dev/null +++ b/db_init.py @@ -0,0 +1,15 @@ +from policyengine.database import Database +from policyengine.models.policyengine_uk import policyengine_uk_latest_version +from policyengine.utils.datasets import create_uk_dataset + +# Load the dataset + +uk_dataset = create_uk_dataset() + +database = Database("postgresql://postgres:postgres@127.0.0.1:54322/postgres") + +# These two lines are not usually needed, but you should use them the first time you set up a new database +database.reset() # Drop and recreate all tables +database.register_model_version( + policyengine_uk_latest_version +) # Add in the model, model version, parameters and baseline parameter values and variables. 
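
db_init.py above drops and recreates every table each time it runs. With
the verify_tables_exist helper that the next file adds to Database, the
reset can be guarded so that reruns keep existing data. A minimal
sketch, assuming the same local Postgres URL; the guard itself is
illustrative and not part of this patch:

    # Sketch: only reset when the schema is incomplete, so rerunning the
    # init script does not wipe existing data. Uses Database.verify_tables_exist
    # from the database.py hunk below.
    from policyengine.database import Database
    from policyengine.models.policyengine_uk import policyengine_uk_latest_version

    database = Database("postgresql://postgres:postgres@127.0.0.1:54322/postgres")

    table_status = database.verify_tables_exist()  # {table_name: exists}
    if not all(table_status.values()):
        database.reset()
        database.register_model_version(policyengine_uk_latest_version)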
diff --git a/src/policyengine/database/database.py b/src/policyengine/database/database.py index c1c20755..2eb0fc40 100644 --- a/src/policyengine/database/database.py +++ b/src/policyengine/database/database.py @@ -21,14 +21,17 @@ class Database: url: str - - _model_table_links: list[TableLink] = [] + _model_table_links: list[TableLink] def __init__(self, url: str): self.url = url self.engine = self._create_engine() self.session = Session(self.engine) + # Initialize instance variable for table links + self._model_table_links = [] + + # Register all table links for link in [ model_table_link, model_version_table_link, @@ -48,7 +51,19 @@ def __init__(self, url: str): def _create_engine(self): from sqlmodel import create_engine - return create_engine(self.url, echo=False) + # Configure engine with proper settings for PostgreSQL/Supabase + engine_args = { + "echo": False, + "pool_pre_ping": True, # Verify connections before using + "pool_recycle": 3600, # Recycle connections after 1 hour + } + + # For PostgreSQL, ensure proper connection pooling + if self.url.startswith("postgresql"): + engine_args["pool_size"] = 5 + engine_args["max_overflow"] = 10 + + return create_engine(self.url, **engine_args) def create_tables(self): """Create all database tables.""" @@ -81,9 +96,34 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.session.close() def register_table(self, link: TableLink): + """Register a table link for use with the database. + + Note: This does NOT create the table. Call create_tables() after + registering all tables to create them in the correct order respecting + foreign key dependencies. + + Args: + link: The TableLink to register + """ self._model_table_links.append(link) - # Create the table if not exists - link.table_cls.metadata.create_all(self.engine) + + def verify_tables_exist(self) -> dict[str, bool]: + """Verify that all registered tables exist in the database. 
+ + Returns: + A dictionary mapping table names to whether they exist + """ + from sqlalchemy import inspect as sql_inspect + + inspector = sql_inspect(self.engine) + existing_tables = set(inspector.get_table_names()) + + results = {} + for link in self._model_table_links: + table_name = link.table_cls.__tablename__ + results[table_name] = table_name in existing_tables + + return results def get(self, model_cls: type, **kwargs): """Get a model instance from the database by its attributes.""" diff --git a/tests/test_database_init.py b/tests/test_database_init.py new file mode 100644 index 00000000..b343e91e --- /dev/null +++ b/tests/test_database_init.py @@ -0,0 +1,259 @@ +"""Test database initialization and table creation.""" + +import sys +from pathlib import Path + +import pytest +from sqlalchemy import inspect + +# Add src to path to allow imports +sys.path.insert(0, str(Path(__file__).parent.parent / "src")) + + +# Module-level functions for testing (can be pickled) +# Note: Don't use test_ prefix or pytest will try to run them +def sim_func(x): + """Simulation function that can be pickled.""" + return x + + +def sim_func_double(x): + """Simulation function that doubles input.""" + return x * 2 + + +@pytest.fixture +def fresh_database(): + """Create a fresh database instance for each test.""" + from policyengine.database import Database + + # Use in-memory SQLite for testing + db = Database(url="sqlite:///:memory:") + return db + + +def test_database_creates_engine(fresh_database): + """Test that database initialization creates an engine.""" + assert fresh_database.engine is not None + assert fresh_database.url == "sqlite:///:memory:" + + +def test_database_creates_session(fresh_database): + """Test that database initialization creates a session.""" + assert fresh_database.session is not None + + +def test_create_tables_creates_all_registered_tables(fresh_database): + """Test that create_tables() creates all registered tables.""" + fresh_database.create_tables() + + # Get inspector to check actual database tables + inspector = inspect(fresh_database.engine) + actual_tables = inspector.get_table_names() + + # Expected tables based on registered table links + expected_tables = { + "models", + "model_versions", + "datasets", + "versioned_datasets", + "policies", + "dynamics", + "parameters", + "parameter_values", + "baseline_parameter_values", + "baseline_variables", + "simulations", + "aggregates", + } + + # Check that all expected tables exist + for table in expected_tables: + assert table in actual_tables, f"Table {table} was not created" + + +def test_register_table_creates_table(fresh_database): + """Test that register_table registers the table link.""" + # Tables should be registered but NOT created until create_tables() is called + inspector = inspect(fresh_database.engine) + initial_tables = set(inspector.get_table_names()) + + # No tables should exist yet + assert len(initial_tables) == 0 + + # But table links should be registered + assert len(fresh_database._model_table_links) == 12 + + # After calling create_tables(), tables should exist + fresh_database.create_tables() + inspector = inspect(fresh_database.engine) + tables_after = set(inspector.get_table_names()) + assert "models" in tables_after + + +def test_reset_drops_and_recreates_tables(fresh_database): + """Test that reset() drops and recreates all tables.""" + from policyengine.models import Model + + # Create tables first + fresh_database.create_tables() + + # Add some data + model = Model( + id="test_model", + 
name="Test", + description="Test model", + simulation_function=sim_func, + ) + fresh_database.set(model) + + # Verify data exists + retrieved = fresh_database.get(Model, id="test_model") + assert retrieved is not None + + # Reset the database + fresh_database.reset() + + # Tables should exist but be empty + inspector = inspect(fresh_database.engine) + tables = inspector.get_table_names() + assert "models" in tables + + # Data should be gone + retrieved_after_reset = fresh_database.get(Model, id="test_model") + assert retrieved_after_reset is None + + +def test_drop_tables_removes_all_tables(fresh_database): + """Test that drop_tables() removes all tables.""" + # Create tables first + fresh_database.create_tables() + + # Verify tables exist + inspector = inspect(fresh_database.engine) + tables_before = inspector.get_table_names() + assert len(tables_before) > 0 + + # Drop tables + fresh_database.drop_tables() + + # Verify tables are gone + inspector = inspect(fresh_database.engine) + tables_after = inspector.get_table_names() + assert len(tables_after) == 0 + + +def test_context_manager_commits_on_success(): + """Test that context manager commits on successful operations.""" + from policyengine.database import Database + from policyengine.models import Model + + db = Database(url="sqlite:///:memory:") + db.create_tables() + + # Use context manager to add data + with db as session: + model = Model( + id="test_context_model", + name="Context Test", + description="Testing context manager", + simulation_function=sim_func, + ) + db.set(model, commit=False) # Don't commit inside set + + # Data should be committed after context exit + retrieved = db.get(Model, id="test_context_model") + assert retrieved is not None + assert retrieved.name == "Context Test" + + +def test_context_manager_rolls_back_on_error(): + """Test that context manager rolls back on errors.""" + from policyengine.database import Database + from policyengine.models import Model + + db = Database(url="sqlite:///:memory:") + db.create_tables() + + # Try to use context manager with an error + try: + with db as session: + model = Model( + id="test_rollback_model", + name="Rollback Test", + description="Testing rollback", + simulation_function=sim_func, + ) + db.set(model, commit=False) + # Raise an error to trigger rollback + raise ValueError("Test error") + except ValueError: + pass + + # Data should NOT be in database due to rollback + retrieved = db.get(Model, id="test_rollback_model") + assert retrieved is None + + +def test_database_url_variations(): + """Test that database works with different URL formats.""" + from policyengine.database import Database + + # Test in-memory SQLite + db1 = Database(url="sqlite:///:memory:") + assert db1.engine is not None + + # Test file-based SQLite + db2 = Database(url="sqlite:///test.db") + assert db2.engine is not None + + +def test_all_table_links_registered(fresh_database): + """Test that all expected table links are registered.""" + expected_count = 12 # Based on the number of table links in __init__ + assert len(fresh_database._model_table_links) == expected_count + + # Verify specific table links exist + from policyengine.models import ( + Aggregate, + Dataset, + Dynamic, + Model, + ModelVersion, + Parameter, + Policy, + Simulation, + VersionedDataset, + ) + + model_classes = [link.model_cls for link in fresh_database._model_table_links] + + assert Model in model_classes + assert ModelVersion in model_classes + assert Dataset in model_classes + assert VersionedDataset in model_classes 
+ assert Policy in model_classes + assert Dynamic in model_classes + assert Parameter in model_classes + assert Simulation in model_classes + assert Aggregate in model_classes + + +def test_verify_tables_exist(fresh_database): + """Test the verify_tables_exist method.""" + # Before creating tables + results_before = fresh_database.verify_tables_exist() + # Some tables may exist from register_table calls during __init__ + # So we just check the method runs + + # After creating tables + fresh_database.create_tables() + results_after = fresh_database.verify_tables_exist() + + # All tables should exist now + assert all(results_after.values()), f"Some tables don't exist: {results_after}" + + # Check specific tables + assert results_after.get("models") is True + assert results_after.get("simulations") is True + assert results_after.get("parameters") is True diff --git a/tests/test_database_postgres.py b/tests/test_database_postgres.py new file mode 100644 index 00000000..1af9796d --- /dev/null +++ b/tests/test_database_postgres.py @@ -0,0 +1,169 @@ +"""Test database with PostgreSQL/Supabase connection. + +These tests verify that table creation and commits work properly with PostgreSQL, +which is what Supabase uses. +""" + +import sys +from pathlib import Path + +import pytest +from sqlalchemy import inspect, text + +# Add src to path to allow imports +sys.path.insert(0, str(Path(__file__).parent.parent / "src")) + + +# Module-level functions for testing (can be pickled) +def sim_func(x): + """Simulation function that can be pickled.""" + return x + + +def sim_func_double(x): + """Simulation function that doubles input.""" + return x * 2 + + +@pytest.fixture +def postgres_database(): + """Create a database instance with a local PostgreSQL connection. + + This requires a local PostgreSQL server running on port 54322. + Skip the test if the connection fails. 
+ """ + from policyengine.database import Database + + try: + db = Database(url="postgresql://postgres:postgres@127.0.0.1:54322/postgres") + # Test connection + with db.engine.connect() as conn: + conn.execute(text("SELECT 1")) + return db + except Exception as e: + pytest.skip(f"PostgreSQL not available: {e}") + + +def test_postgres_create_tables(postgres_database): + """Test that create_tables() works with PostgreSQL.""" + # Drop tables first to ensure clean state + postgres_database.drop_tables() + + # Create tables + postgres_database.create_tables() + + # Verify tables exist + inspector = inspect(postgres_database.engine) + actual_tables = inspector.get_table_names() + + expected_tables = { + "models", + "model_versions", + "datasets", + "versioned_datasets", + "policies", + "dynamics", + "parameters", + "parameter_values", + "baseline_parameter_values", + "baseline_variables", + "simulations", + "aggregates", + } + + for table in expected_tables: + assert table in actual_tables, f"Table {table} was not created in PostgreSQL" + + +def test_postgres_insert_and_retrieve(postgres_database): + """Test that data can be inserted and retrieved from PostgreSQL.""" + from policyengine.models import Model + + # Reset database + postgres_database.reset() + + # Create a model + model = Model( + id="postgres_test_model", + name="PostgreSQL Test", + description="Testing PostgreSQL", + simulation_function=sim_func_double, + ) + + # Insert + postgres_database.set(model) + + # Retrieve + retrieved = postgres_database.get(Model, id="postgres_test_model") + + assert retrieved is not None + assert retrieved.id == "postgres_test_model" + assert retrieved.name == "PostgreSQL Test" + assert retrieved.simulation_function(5) == 10 + + +def test_postgres_session_commit(postgres_database): + """Test that session commits work properly with PostgreSQL.""" + from policyengine.models import Model + + # Reset database + postgres_database.reset() + + # Add data without committing + model = Model( + id="commit_test_model", + name="Commit Test", + description="Testing commit", + simulation_function=sim_func, + ) + + # Set with explicit commit + postgres_database.set(model, commit=True) + + # Create a NEW database connection to verify commit + from policyengine.database import Database + + new_db = Database(url="postgresql://postgres:postgres@127.0.0.1:54322/postgres") + retrieved = new_db.get(Model, id="commit_test_model") + + assert retrieved is not None + assert retrieved.name == "Commit Test" + + +def test_postgres_table_persistence(postgres_database): + """Test that tables persist across database reconnections.""" + # Create tables + postgres_database.reset() + + # Close connection + postgres_database.session.close() + + # Create new database instance with same URL + from policyengine.database import Database + + new_db = Database(url="postgresql://postgres:postgres@127.0.0.1:54322/postgres") + + # Tables should still exist + inspector = inspect(new_db.engine) + tables = inspector.get_table_names() + + assert "models" in tables + assert "simulations" in tables + + +def test_postgres_register_model_version(postgres_database): + """Test that register_model_version works with PostgreSQL.""" + # This test verifies that the bulk registration of model versions + # properly commits to PostgreSQL + postgres_database.reset() + + # This would typically use policyengine_uk_latest_version + # For now, we'll just verify the reset worked + inspector = inspect(postgres_database.engine) + tables = inspector.get_table_names() + + 
assert "models" in tables + assert "model_versions" in tables + assert "parameters" in tables + assert "baseline_parameter_values" in tables + assert "baseline_variables" in tables From 9822cb12b3b53cca7b21521c227be0623dd47a69 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Thu, 9 Oct 2025 12:40:14 +0100 Subject: [PATCH 10/35] Update --- src/policyengine/models/aggregate.py | 321 +++++++++----- src/policyengine/models/aggregate_change.py | 306 +------------- tests/test_aggregate_utils.py | 441 ++++++++++++++++++++ 3 files changed, 682 insertions(+), 386 deletions(-) create mode 100644 tests/test_aggregate_utils.py diff --git a/src/policyengine/models/aggregate.py b/src/policyengine/models/aggregate.py index bb580c94..bf8e110b 100644 --- a/src/policyengine/models/aggregate.py +++ b/src/policyengine/models/aggregate.py @@ -17,28 +17,27 @@ class AggregateType(str, Enum): COUNT = "count" -class Aggregate(BaseModel): - id: str = Field(default_factory=lambda: str(uuid4())) - simulation: "Simulation | None" = None - entity: str | None = None - variable_name: str - year: int | None = None - filter_variable_name: str | None = None - filter_variable_value: str | None = None - filter_variable_leq: float | None = None - filter_variable_geq: float | None = None - filter_variable_quantile_leq: float | None = None - filter_variable_quantile_geq: float | None = None - filter_variable_quantile_value: str | None = None - aggregate_function: Literal[ - AggregateType.SUM, AggregateType.MEAN, AggregateType.MEDIAN, AggregateType.COUNT - ] - reportelement_id: str | None = None +class AggregateUtils: + """Shared utilities for aggregate calculations.""" - value: float | None = None + @staticmethod + def prepare_tables(simulation: "Simulation") -> dict: + """Prepare dataframes from simulation result once.""" + tables = simulation.result + tables = {k: v.copy() for k, v in tables.items()} + + for table in tables: + tables[table] = pd.DataFrame(tables[table]) + weight_col = f"{table}_weight" + if weight_col in tables[table].columns: + tables[table] = MicroDataFrame( + tables[table], weights=weight_col + ) + + return tables @staticmethod - def _infer_entity(variable_name: str, tables: dict) -> str: + def infer_entity(variable_name: str, tables: dict) -> str: """Infer entity from variable name by checking which table contains it.""" for entity, table in tables.items(): if variable_name in table.columns: @@ -46,10 +45,10 @@ def _infer_entity(variable_name: str, tables: dict) -> str: raise ValueError(f"Variable {variable_name} not found in any entity table") @staticmethod - def _get_entity_link_columns() -> dict: + def get_entity_link_columns() -> dict: """Return mapping of entity relationships for common PolicyEngine models.""" return { - # person -> group entity links + # person -> group entity links (copy values down) "person": { "benunit": "person_benunit_id", "household": "person_household_id", @@ -57,9 +56,130 @@ def _get_entity_link_columns() -> dict: "tax_unit": "person_tax_unit_id", "spm_unit": "person_spm_unit_id", }, - # Group entities don't have direct upward links typically } + @staticmethod + def map_variable_across_entities( + df: pd.DataFrame, + variable_name: str, + source_entity: str, + target_entity: str, + tables: dict + ) -> pd.Series: + """Map a variable from source entity to target entity level.""" + links = AggregateUtils.get_entity_link_columns() + + # Group to person: copy group values to persons using link column + if source_entity != "person" and target_entity == "person": + link_col = 
links.get("person", {}).get(source_entity) + if link_col is None: + raise ValueError(f"No known link from person to {source_entity}") + + if link_col not in tables["person"].columns: + raise ValueError(f"Link column {link_col} not found in person table") + + # Create mapping: group position (0-based index) -> value + group_values = df[variable_name].values + + # Map to person level using the link column + person_table = tables["person"] + person_group_ids = person_table[link_col].values + + # Map each person to their group's value + result = pd.Series([group_values[int(gid)] if int(gid) < len(group_values) else 0 + for gid in person_group_ids], index=person_table.index) + return result + + # Person to group: sum persons' values to group level + elif source_entity == "person" and target_entity != "person": + link_col = links.get("person", {}).get(target_entity) + if link_col is None: + raise ValueError(f"No known link from person to {target_entity}") + + if link_col not in df.columns: + raise ValueError(f"Link column {link_col} not found in person table") + + # Sum by group - need to align with group table length + grouped = df.groupby(link_col)[variable_name].sum() + + # Create a series aligned with the group table + group_table = tables[target_entity] + result = pd.Series([grouped.get(i, 0) for i in range(len(group_table))], + index=group_table.index) + return result + + # Group to group: try via person as intermediary + elif source_entity != "person" and target_entity != "person": + # Map source -> person -> target + person_values = AggregateUtils.map_variable_across_entities( + df, variable_name, source_entity, "person", tables + ) + # Create temp dataframe with person values + temp_person_df = tables["person"].copy() + temp_person_df[variable_name] = person_values + + return AggregateUtils.map_variable_across_entities( + temp_person_df, variable_name, "person", target_entity, tables + ) + + else: + # Same entity - shouldn't happen but return as-is + return df[variable_name] + + @staticmethod + def compute_aggregate( + variable_series: pd.Series | MicroDataFrame, + aggregate_function: str + ) -> float: + """Compute aggregate value from a series.""" + if len(variable_series) == 0: + return 0.0 + + # Check if this is a weight column being summed from a MicroDataFrame + # If so, use unweighted sum to avoid weight^2 + is_weight_column = ( + isinstance(variable_series, pd.Series) and + hasattr(variable_series, 'name') and + variable_series.name and + 'weight' in str(variable_series.name).lower() + ) + + if aggregate_function == AggregateType.SUM: + if is_weight_column: + # Use unweighted sum for weight columns + return float(pd.Series(variable_series.values).sum()) + return float(variable_series.sum()) + elif aggregate_function == AggregateType.MEAN: + return float(variable_series.mean()) + elif aggregate_function == AggregateType.MEDIAN: + return float(variable_series.median()) + elif aggregate_function == AggregateType.COUNT: + # For COUNT, return the actual number of entries (not weighted) + # Use len() to count all entries regardless of value + return float(len(variable_series)) + else: + raise ValueError(f"Unknown aggregate function: {aggregate_function}") + + +class Aggregate(BaseModel): + id: str = Field(default_factory=lambda: str(uuid4())) + simulation: "Simulation | None" = None + entity: str | None = None + variable_name: str + year: int | None = None + filter_variable_name: str | None = None + filter_variable_value: str | None = None + filter_variable_leq: float | None = None + 
filter_variable_geq: float | None = None + filter_variable_quantile_leq: float | None = None + filter_variable_quantile_geq: float | None = None + aggregate_function: Literal[ + AggregateType.SUM, AggregateType.MEAN, AggregateType.MEDIAN, AggregateType.COUNT + ] + reportelement_id: str | None = None + + value: float | None = None + @staticmethod def run(aggregates: list["Aggregate"]) -> list["Aggregate"]: """Process aggregates, handling multiple simulations if necessary.""" @@ -97,100 +217,113 @@ def _process_simulation_aggregates( """Process aggregates for a single simulation.""" results = [] - tables = simulation.result - # Copy tables to ensure we don't modify original dataframes - tables = {k: v.copy() for k, v in tables.items()} - for table in tables: - tables[table] = pd.DataFrame(tables[table]) - weight_col = f"{table}_weight" - if weight_col in tables[table].columns: - tables[table] = MicroDataFrame( - tables[table], weights=weight_col - ) + # Use centralized table preparation + tables = AggregateUtils.prepare_tables(simulation) for agg in aggregates: # Infer entity if not provided if agg.entity is None: - agg.entity = Aggregate._infer_entity(agg.variable_name, tables) + agg.entity = AggregateUtils.infer_entity(agg.variable_name, tables) if agg.entity not in tables: raise ValueError( f"Entity {agg.entity} not found in simulation results" ) - table = tables[agg.entity] - - if agg.variable_name not in table.columns: - raise ValueError( - f"Variable {agg.variable_name} not found in entity {agg.entity}" - ) - - df = table if agg.year is None: agg.year = simulation.dataset.year + # Get the target table + target_table = tables[agg.entity] + + # Handle cross-entity filters + mask = None if agg.filter_variable_name is not None: - if agg.filter_variable_name not in df.columns: + # Find which entity contains the filter variable + filter_entity = None + for entity, table in tables.items(): + if agg.filter_variable_name in table.columns: + filter_entity = entity + break + + if filter_entity is None: raise ValueError( - f"Filter variable {agg.filter_variable_name} not found in entity {agg.entity}" + f"Filter variable {agg.filter_variable_name} not found in any entity" ) + # Get the filter series (mapped if needed) + if filter_entity == agg.entity: + filter_series = tables[agg.entity][agg.filter_variable_name] + else: + # Different entity - map filter variable to target entity + filter_df = tables[filter_entity] + filter_series = AggregateUtils.map_variable_across_entities( + filter_df, + agg.filter_variable_name, + filter_entity, + agg.entity, + tables + ) + + # Build mask + mask = pd.Series([True] * len(target_table), index=target_table.index) + # Apply value/range filters if agg.filter_variable_value is not None: - df = df[ - df[agg.filter_variable_name] - == agg.filter_variable_value - ] + mask &= filter_series == agg.filter_variable_value if agg.filter_variable_leq is not None: - df = df[ - df[agg.filter_variable_name] <= agg.filter_variable_leq - ] + mask &= filter_series <= agg.filter_variable_leq if agg.filter_variable_geq is not None: - df = df[ - df[agg.filter_variable_name] >= agg.filter_variable_geq - ] - - # Apply quantile filters if specified - if any([agg.filter_variable_quantile_leq, - agg.filter_variable_quantile_geq, agg.filter_variable_quantile_value]): - - if agg.filter_variable_quantile_leq is not None: - # Filter to values <= specified quantile - threshold = df[agg.filter_variable_name].quantile(agg.filter_variable_quantile_leq) - df = df[df[agg.filter_variable_name] <= 
threshold] - - if agg.filter_variable_quantile_geq is not None: - # Filter to values >= specified quantile - threshold = df[agg.filter_variable_name].quantile(agg.filter_variable_quantile_geq) - df = df[df[agg.filter_variable_name] >= threshold] - - if agg.filter_variable_quantile_value is not None: - # Parse quantile value like "top_10%" or "bottom_20%" - if "top" in agg.filter_variable_quantile_value.lower(): - pct = float(agg.filter_variable_quantile_value.lower().replace("top_", "").replace("%", "")) / 100 - threshold = df[agg.filter_variable_name].quantile(1 - pct) - df = df[df[agg.filter_variable_name] >= threshold] - elif "bottom" in agg.filter_variable_quantile_value.lower(): - pct = float(agg.filter_variable_quantile_value.lower().replace("bottom_", "").replace("%", "")) / 100 - threshold = df[agg.filter_variable_name].quantile(pct) - df = df[df[agg.filter_variable_name] <= threshold] - - # Check if we have any data left after filtering - if len(df) == 0: - agg.value = 0.0 + mask &= filter_series >= agg.filter_variable_geq + + # Apply quantile filters + if agg.filter_variable_quantile_leq is not None: + threshold = filter_series.quantile(agg.filter_variable_quantile_leq) + mask &= filter_series <= threshold + if agg.filter_variable_quantile_geq is not None: + threshold = filter_series.quantile(agg.filter_variable_quantile_geq) + mask &= filter_series >= threshold + + # Find which entity contains the variable + variable_entity = None + for entity, table in tables.items(): + if agg.variable_name in table.columns: + variable_entity = entity + break + + if variable_entity is None: + raise ValueError( + f"Variable {agg.variable_name} not found in any entity" + ) + + # Get variable data (mapped if needed) + if variable_entity == agg.entity: + # Same entity - extract from target table + if mask is not None: + # Filter the entire table to preserve MicroDataFrame weights + filtered_table = target_table[mask] + variable_series = filtered_table[agg.variable_name] + else: + variable_series = target_table[agg.variable_name] else: - try: - if agg.aggregate_function == AggregateType.SUM: - agg.value = float(df[agg.variable_name].sum()) - elif agg.aggregate_function == AggregateType.MEAN: - agg.value = float(df[agg.variable_name].mean()) - elif agg.aggregate_function == AggregateType.MEDIAN: - agg.value = float(df[agg.variable_name].median()) - elif agg.aggregate_function == AggregateType.COUNT: - agg.value = float((df[agg.variable_name] > 0).sum()) - except (ZeroDivisionError, ValueError): - # Handle cases where weights sum to zero - agg.value = 0.0 + # Map variable to target entity + source_table = tables[variable_entity] + variable_series = AggregateUtils.map_variable_across_entities( + source_table, + agg.variable_name, + variable_entity, + agg.entity, + tables + ) + # Apply mask after mapping + if mask is not None: + variable_series = variable_series[mask] + + # Compute aggregate using centralized function + agg.value = AggregateUtils.compute_aggregate( + variable_series, + agg.aggregate_function + ) results.append(agg) diff --git a/src/policyengine/models/aggregate_change.py b/src/policyengine/models/aggregate_change.py index 6c0b9ccb..49c05266 100644 --- a/src/policyengine/models/aggregate_change.py +++ b/src/policyengine/models/aggregate_change.py @@ -3,21 +3,15 @@ from uuid import uuid4 import pandas as pd -from microdf import MicroDataFrame from pydantic import BaseModel, Field import time +from .aggregate import AggregateUtils, AggregateType + if TYPE_CHECKING: from policyengine.models 
import Simulation -class AggregateType(str, Enum): - SUM = "sum" - MEAN = "mean" - MEDIAN = "median" - COUNT = "count" - - class AggregateChange(BaseModel): id: str = Field(default_factory=lambda: str(uuid4())) baseline_simulation: "Simulation | None" = None @@ -31,7 +25,6 @@ class AggregateChange(BaseModel): filter_variable_geq: float | None = None filter_variable_quantile_leq: float | None = None filter_variable_quantile_geq: float | None = None - filter_variable_quantile_value: str | None = None aggregate_function: Literal[ AggregateType.SUM, AggregateType.MEAN, AggregateType.MEDIAN, AggregateType.COUNT ] @@ -42,118 +35,6 @@ class AggregateChange(BaseModel): change: float | None = None relative_change: float | None = None - @staticmethod - def _infer_entity(variable_name: str, filter_variable_name: str | None, tables: dict) -> str: - """Infer entity from the target variable (not the filter variable). - - The entity represents what level we're aggregating at, determined by the target variable. - Filters can be cross-entity and will be mapped if needed. - """ - # Find entity of target variable - for entity, table in tables.items(): - if variable_name in table.columns: - return entity - - raise ValueError(f"Variable {variable_name} not found in any entity table") - - @staticmethod - def _get_entity_link_columns() -> dict: - """Return mapping of entity relationships for common PolicyEngine models.""" - return { - # person -> group entity links (copy values down) - "person": { - "benunit": "person_benunit_id", - "household": "person_household_id", - "family": "person_family_id", - "tax_unit": "person_tax_unit_id", - "spm_unit": "person_spm_unit_id", - }, - } - - @staticmethod - def _map_variable_across_entities( - df: pd.DataFrame, - variable_name: str, - source_entity: str, - target_entity: str, - tables: dict - ) -> pd.Series: - """Map a variable from source entity to target entity level.""" - links = AggregateChange._get_entity_link_columns() - - # Group to person: copy group values to persons using link column - if source_entity != "person" and target_entity == "person": - link_col = links.get("person", {}).get(source_entity) - if link_col is None: - raise ValueError(f"No known link from person to {source_entity}") - - if link_col not in tables["person"].columns: - raise ValueError(f"Link column {link_col} not found in person table") - - # Create mapping: group position (0-based index) -> value - # Most PolicyEngine models have entities numbered 0, 1, 2, ... 
- group_values = df[variable_name].values - - # Map to person level using the link column - person_table = tables["person"] - person_group_ids = person_table[link_col].values - - # Map each person to their group's value - result = pd.Series([group_values[int(gid)] if int(gid) < len(group_values) else 0 - for gid in person_group_ids], index=person_table.index) - return result - - # Person to group: sum persons' values to group level - elif source_entity == "person" and target_entity != "person": - link_col = links.get("person", {}).get(target_entity) - if link_col is None: - raise ValueError(f"No known link from person to {target_entity}") - - if link_col not in df.columns: - raise ValueError(f"Link column {link_col} not found in person table") - - # Sum by group - need to align with group table length - grouped = df.groupby(link_col)[variable_name].sum() - - # Create a series aligned with the group table - group_table = tables[target_entity] - result = pd.Series([grouped.get(i, 0) for i in range(len(group_table))], - index=group_table.index) - return result - - # Group to group: try via person as intermediary - elif source_entity != "person" and target_entity != "person": - # Map source -> person -> target - person_values = AggregateChange._map_variable_across_entities( - df, variable_name, source_entity, "person", tables - ) - # Create temp dataframe with person values - temp_person_df = tables["person"].copy() - temp_person_df[variable_name] = person_values - - return AggregateChange._map_variable_across_entities( - temp_person_df, variable_name, "person", target_entity, tables - ) - - else: - # Same entity - shouldn't happen but return as-is - return df[variable_name] - - @staticmethod - def _prepare_tables(simulation: "Simulation") -> dict: - """Prepare dataframes from simulation result once.""" - tables = simulation.result - tables = {k: v.copy() for k, v in tables.items()} - - for table in tables: - tables[table] = pd.DataFrame(tables[table]) - weight_col = f"{table}_weight" - if weight_col in tables[table].columns: - tables[table] = MicroDataFrame( - tables[table], weights=weight_col - ) - - return tables @staticmethod def run(aggregate_changes: list["AggregateChange"]) -> list["AggregateChange"]: @@ -186,8 +67,8 @@ def run(aggregate_changes: list["AggregateChange"]) -> list["AggregateChange"]: comparison_sim = group[0].comparison_simulation # Pre-compute simulation dataframes once per batch - baseline_tables = AggregateChange._prepare_tables(baseline_sim) - comparison_tables = AggregateChange._prepare_tables(comparison_sim) + baseline_tables = AggregateUtils.prepare_tables(baseline_sim) + comparison_tables = AggregateUtils.prepare_tables(comparison_sim) prep_time = time.time() print(f"[PERFORMANCE] Table preparation took {prep_time - group_start:.3f} seconds") @@ -196,11 +77,10 @@ def run(aggregate_changes: list["AggregateChange"]) -> list["AggregateChange"]: for idx, agg_change in enumerate(group): item_start = time.time() - # Infer entity if not provided (use filter variable entity if available) + # Infer entity if not provided if agg_change.entity is None: - agg_change.entity = AggregateChange._infer_entity( + agg_change.entity = AggregateUtils.infer_entity( agg_change.variable_name, - agg_change.filter_variable_name, baseline_tables ) @@ -272,7 +152,7 @@ def _get_filter_mask_from_tables( else: # Different entity - need to map filter variable to target entity filter_df = tables[filter_entity] - mapped_filter = AggregateChange._map_variable_across_entities( + mapped_filter = 
AggregateUtils.map_variable_across_entities( filter_df, agg_change.filter_variable_name, filter_entity, @@ -293,9 +173,7 @@ def _get_filter_mask_from_tables( if agg_change.filter_variable_geq is not None: mask &= filter_series >= agg_change.filter_variable_geq - if any([agg_change.filter_variable_quantile_leq, - agg_change.filter_variable_quantile_geq, agg_change.filter_variable_quantile_value]): - + if agg_change.filter_variable_quantile_leq is not None or agg_change.filter_variable_quantile_geq is not None: if agg_change.filter_variable_quantile_leq is not None: threshold = filter_series.quantile(agg_change.filter_variable_quantile_leq) mask &= filter_series <= threshold @@ -304,16 +182,6 @@ def _get_filter_mask_from_tables( threshold = filter_series.quantile(agg_change.filter_variable_quantile_geq) mask &= filter_series >= threshold - if agg_change.filter_variable_quantile_value is not None: - if "top" in agg_change.filter_variable_quantile_value.lower(): - pct = float(agg_change.filter_variable_quantile_value.lower().replace("top_", "").replace("%", "")) / 100 - threshold = filter_series.quantile(1 - pct) - mask &= filter_series >= threshold - elif "bottom" in agg_change.filter_variable_quantile_value.lower(): - pct = float(agg_change.filter_variable_quantile_value.lower().replace("bottom_", "").replace("%", "")) / 100 - threshold = filter_series.quantile(pct) - mask &= filter_series <= threshold - return mask @staticmethod @@ -350,7 +218,7 @@ def _compute_single_aggregate_from_tables( # Map it to the target entity level try: - mapped_series = AggregateChange._map_variable_across_entities( + mapped_series = AggregateUtils.map_variable_across_entities( source_table, agg_change.variable_name, variable_entity, @@ -377,155 +245,9 @@ def _compute_single_aggregate_from_tables( if len(df) == 0: return 0.0 - try: - if agg_change.aggregate_function == AggregateType.SUM: - value = float(df[agg_change.variable_name].sum()) - elif agg_change.aggregate_function == AggregateType.MEAN: - value = float(df[agg_change.variable_name].mean()) - elif agg_change.aggregate_function == AggregateType.MEDIAN: - value = float(df[agg_change.variable_name].median()) - elif agg_change.aggregate_function == AggregateType.COUNT: - value = float((df[agg_change.variable_name] > 0).sum()) - else: - raise ValueError(f"Unknown aggregate function: {agg_change.aggregate_function}") - except (ZeroDivisionError, ValueError) as e: - return 0.0 - - return value - - @staticmethod - def _get_filter_mask( - agg_change: "AggregateChange", simulation: "Simulation" - ) -> pd.Series | None: - """Get filter mask based on baseline simulation values.""" - if agg_change.filter_variable_name is None: - return None # No filtering needed - - tables = simulation.result - tables = {k: v.copy() for k, v in tables.items()} - - for table in tables: - tables[table] = pd.DataFrame(tables[table]) - weight_col = f"{table}_weight" - if weight_col in tables[table].columns: - tables[table] = MicroDataFrame( - tables[table], weights=weight_col - ) - - if agg_change.entity not in tables: - raise ValueError( - f"Entity {agg_change.entity} not found in simulation results" - ) - - df = tables[agg_change.entity] - - if agg_change.filter_variable_name not in df.columns: - raise ValueError( - f"Filter variable {agg_change.filter_variable_name} not found in entity {agg_change.entity}" - ) - - # Create filter mask based on baseline values - mask = pd.Series([True] * len(df), index=df.index) - - # Apply value/range filters - if agg_change.filter_variable_value is 
not None: - mask &= df[agg_change.filter_variable_name] == agg_change.filter_variable_value - - if agg_change.filter_variable_leq is not None: - mask &= df[agg_change.filter_variable_name] <= agg_change.filter_variable_leq - - if agg_change.filter_variable_geq is not None: - mask &= df[agg_change.filter_variable_name] >= agg_change.filter_variable_geq - - # Apply quantile filters if specified - if any([agg_change.filter_variable_quantile_leq, - agg_change.filter_variable_quantile_geq, agg_change.filter_variable_quantile_value]): - - if agg_change.filter_variable_quantile_leq is not None: - # Filter to values <= specified quantile - threshold = df[agg_change.filter_variable_name].quantile(agg_change.filter_variable_quantile_leq) - mask &= df[agg_change.filter_variable_name] <= threshold - - if agg_change.filter_variable_quantile_geq is not None: - # Filter to values >= specified quantile - threshold = df[agg_change.filter_variable_name].quantile(agg_change.filter_variable_quantile_geq) - mask &= df[agg_change.filter_variable_name] >= threshold - - if agg_change.filter_variable_quantile_value is not None: - # Parse quantile value like "top_10%" or "bottom_20%" - if "top" in agg_change.filter_variable_quantile_value.lower(): - pct = float(agg_change.filter_variable_quantile_value.lower().replace("top_", "").replace("%", "")) / 100 - threshold = df[agg_change.filter_variable_name].quantile(1 - pct) - mask &= df[agg_change.filter_variable_name] >= threshold - elif "bottom" in agg_change.filter_variable_quantile_value.lower(): - pct = float(agg_change.filter_variable_quantile_value.lower().replace("bottom_", "").replace("%", "")) / 100 - threshold = df[agg_change.filter_variable_name].quantile(pct) - mask &= df[agg_change.filter_variable_name] <= threshold - - return mask - - @staticmethod - def _compute_single_aggregate( - agg_change: "AggregateChange", - simulation: "Simulation", - filter_mask: pd.Series | None = None - ) -> float: - """Compute aggregate value for a single simulation.""" - compute_start = time.time() - tables = simulation.result - # Copy tables to ensure we don't modify original dataframes - tables = {k: v.copy() for k, v in tables.items()} - - for table in tables: - tables[table] = pd.DataFrame(tables[table]) - weight_col = f"{table}_weight" - if weight_col in tables[table].columns: - tables[table] = MicroDataFrame( - tables[table], weights=weight_col - ) - - if agg_change.entity not in tables: - raise ValueError( - f"Entity {agg_change.entity} not found in simulation results" - ) - - table = tables[agg_change.entity] - - if agg_change.variable_name not in table.columns: - raise ValueError( - f"Variable {agg_change.variable_name} not found in entity {agg_change.entity}" - ) - - df = table - - if agg_change.year is None: - agg_change.year = simulation.dataset.year - - # Apply the pre-computed filter mask if provided - # This ensures we're using the same subset of entities for both baseline and comparison - if filter_mask is not None: - df = df[filter_mask] - - # Check if we have any data left after filtering - if len(df) == 0: - # Return 0 for empty datasets - return 0.0 - - # Compute aggregate - try: - if agg_change.aggregate_function == AggregateType.SUM: - value = float(df[agg_change.variable_name].sum()) - elif agg_change.aggregate_function == AggregateType.MEAN: - value = float(df[agg_change.variable_name].mean()) - elif agg_change.aggregate_function == AggregateType.MEDIAN: - value = float(df[agg_change.variable_name].median()) - elif agg_change.aggregate_function == 
AggregateType.COUNT: - value = float((df[agg_change.variable_name] > 0).sum()) - else: - raise ValueError(f"Unknown aggregate function: {agg_change.aggregate_function}") - except (ZeroDivisionError, ValueError) as e: - # Handle cases where weights sum to zero or other computation errors - # Return 0 for these edge cases - return 0.0 + # Use centralized compute function + return AggregateUtils.compute_aggregate( + df[agg_change.variable_name], + agg_change.aggregate_function + ) - return value \ No newline at end of file diff --git a/tests/test_aggregate_utils.py b/tests/test_aggregate_utils.py new file mode 100644 index 00000000..39090dc6 --- /dev/null +++ b/tests/test_aggregate_utils.py @@ -0,0 +1,441 @@ +""" +Unit tests for AggregateUtils cross-entity mapping functionality. + +These tests verify that variables can be correctly mapped between entities +(person <-> household, person <-> tax_unit, etc.) while preserving weights. +""" + +import pytest +import pandas as pd +from microdf import MicroDataFrame +from policyengine.models.aggregate import AggregateUtils, AggregateType + + +class MockSimulation: + """Mock simulation for testing.""" + def __init__(self, result): + self.result = result + + +@pytest.fixture +def sample_person_household_tables(): + """ + Create sample tables with person and household entities. + + Structure: + - 2 households (ids 0, 1) + - 4 persons (ids 0, 1, 2, 3) + - Persons 0, 1 belong to household 0 + - Persons 2, 3 belong to household 1 + """ + person_table = pd.DataFrame({ + 'person_id': [0, 1, 2, 3], + 'person_household_id': [0, 0, 1, 1], + 'person_weight': [1.0, 1.0, 1.0, 1.0], + 'age': [30, 5, 45, 40], + 'employment_income': [50000, 0, 60000, 55000], + }) + + household_table = pd.DataFrame({ + 'household_id': [0, 1], + 'household_weight': [1.0, 1.0], + 'household_net_income': [50000, 115000], + 'is_in_poverty': [1, 0], # household 0 is in poverty, household 1 is not + }) + + return { + 'person': person_table, + 'household': household_table + } + + +@pytest.fixture +def weighted_person_household_tables(): + """ + Create sample tables with different weights to test weight preservation. 
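+    Weights of 100 and 200 make weighted and unweighted results easy to
+    tell apart in the assertions below.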
+ + Structure: + - 2 households with different weights + - 4 persons with weights matching their households + """ + person_table = pd.DataFrame({ + 'person_id': [0, 1, 2, 3], + 'person_household_id': [0, 0, 1, 1], + 'person_weight': [100.0, 100.0, 200.0, 200.0], # Different weights + 'age': [30, 5, 45, 40], + 'employment_income': [50000, 0, 60000, 55000], + }) + + household_table = pd.DataFrame({ + 'household_id': [0, 1], + 'household_weight': [100.0, 200.0], # Different weights + 'household_net_income': [50000, 115000], + 'is_in_poverty': [1, 0], + }) + + return { + 'person': person_table, + 'household': household_table + } + + +class TestPrepareTablesWithWeights: + """Test that prepare_tables correctly creates MicroDataFrames with weights.""" + + def test_prepare_tables_creates_microdataframes(self, sample_person_household_tables): + """Test that tables with weight columns become MicroDataFrames.""" + mock_sim = MockSimulation(sample_person_household_tables) + tables = AggregateUtils.prepare_tables(mock_sim) + + # Both tables should be MicroDataFrames + assert isinstance(tables['person'], MicroDataFrame) + assert isinstance(tables['household'], MicroDataFrame) + + # Check that weights are set correctly + assert tables['person'].weights_col == 'person_weight' + assert tables['household'].weights_col == 'household_weight' + + def test_prepare_tables_without_weights(self): + """Test that tables without weight columns remain regular DataFrames.""" + tables_without_weights = { + 'person': pd.DataFrame({ + 'person_id': [0, 1], + 'age': [30, 40] + }) + } + mock_sim = MockSimulation(tables_without_weights) + tables = AggregateUtils.prepare_tables(mock_sim) + + # Should be regular DataFrame, not MicroDataFrame + assert isinstance(tables['person'], pd.DataFrame) + assert not isinstance(tables['person'], MicroDataFrame) + + +class TestInferEntity: + """Test entity inference from variable names.""" + + def test_infer_entity_person_variable(self, sample_person_household_tables): + """Test inferring entity for a person-level variable.""" + entity = AggregateUtils.infer_entity('age', sample_person_household_tables) + assert entity == 'person' + + def test_infer_entity_household_variable(self, sample_person_household_tables): + """Test inferring entity for a household-level variable.""" + entity = AggregateUtils.infer_entity('is_in_poverty', sample_person_household_tables) + assert entity == 'household' + + def test_infer_entity_nonexistent_variable(self, sample_person_household_tables): + """Test that nonexistent variable raises ValueError.""" + with pytest.raises(ValueError, match="Variable nonexistent not found"): + AggregateUtils.infer_entity('nonexistent', sample_person_household_tables) + + +class TestMapVariableAcrossEntities: + """Test cross-entity variable mapping.""" + + def test_map_household_to_person(self, sample_person_household_tables): + """ + Test mapping a household variable to person level. + Each person should get their household's value. + """ + mapped = AggregateUtils.map_variable_across_entities( + sample_person_household_tables['household'], + 'is_in_poverty', + 'household', + 'person', + sample_person_household_tables + ) + + # Persons 0, 1 belong to household 0 (in poverty) + assert mapped.iloc[0] == 1 + assert mapped.iloc[1] == 1 + + # Persons 2, 3 belong to household 1 (not in poverty) + assert mapped.iloc[2] == 0 + assert mapped.iloc[3] == 0 + + def test_map_person_to_household(self, sample_person_household_tables): + """ + Test mapping a person variable to household level. 
+ Should sum persons' values within each household. + """ + mapped = AggregateUtils.map_variable_across_entities( + sample_person_household_tables['person'], + 'employment_income', + 'person', + 'household', + sample_person_household_tables + ) + + # Household 0: persons 0 (50000) + 1 (0) = 50000 + assert mapped.iloc[0] == 50000 + + # Household 1: persons 2 (60000) + 3 (55000) = 115000 + assert mapped.iloc[1] == 115000 + + def test_map_same_entity(self, sample_person_household_tables): + """Test that mapping to same entity returns the variable as-is.""" + mapped = AggregateUtils.map_variable_across_entities( + sample_person_household_tables['person'], + 'age', + 'person', + 'person', + sample_person_household_tables + ) + + pd.testing.assert_series_equal( + mapped, + sample_person_household_tables['person']['age'], + check_names=False + ) + + def test_map_preserves_length(self, sample_person_household_tables): + """Test that mapped series has correct length for target entity.""" + # Household to person: should have 4 entries (4 persons) + mapped_h_to_p = AggregateUtils.map_variable_across_entities( + sample_person_household_tables['household'], + 'is_in_poverty', + 'household', + 'person', + sample_person_household_tables + ) + assert len(mapped_h_to_p) == 4 + + # Person to household: should have 2 entries (2 households) + mapped_p_to_h = AggregateUtils.map_variable_across_entities( + sample_person_household_tables['person'], + 'employment_income', + 'person', + 'household', + sample_person_household_tables + ) + assert len(mapped_p_to_h) == 2 + + +class TestComputeAggregate: + """Test aggregate computation functions.""" + + def test_sum_simple(self): + """Test simple sum aggregation.""" + series = pd.Series([10, 20, 30, 40]) + result = AggregateUtils.compute_aggregate(series, AggregateType.SUM) + assert result == 100.0 + + def test_sum_weighted(self): + """Test weighted sum using MicroDataFrame.""" + df = pd.DataFrame({ + 'value': [10, 20, 30], + 'weight': [1.0, 2.0, 1.0] + }) + mdf = MicroDataFrame(df, weights='weight') + result = AggregateUtils.compute_aggregate(mdf['value'], AggregateType.SUM) + + # Weighted sum: 10*1 + 20*2 + 30*1 = 80 + assert result == 80.0 + + def test_mean_simple(self): + """Test simple mean aggregation.""" + series = pd.Series([10, 20, 30, 40]) + result = AggregateUtils.compute_aggregate(series, AggregateType.MEAN) + assert result == 25.0 + + def test_mean_weighted(self): + """Test weighted mean using MicroDataFrame.""" + df = pd.DataFrame({ + 'value': [10, 20, 30], + 'weight': [1.0, 2.0, 1.0] + }) + mdf = MicroDataFrame(df, weights='weight') + result = AggregateUtils.compute_aggregate(mdf['value'], AggregateType.MEAN) + + # Weighted mean: (10*1 + 20*2 + 30*1) / (1+2+1) = 80/4 = 20 + assert result == 20.0 + + def test_median_simple(self): + """Test simple median aggregation.""" + series = pd.Series([10, 20, 30, 40, 50]) + result = AggregateUtils.compute_aggregate(series, AggregateType.MEDIAN) + assert result == 30.0 + + def test_count(self): + """Test count aggregation (counts all entries).""" + series = pd.Series([0, 10, 0, 20, 30, 0]) + result = AggregateUtils.compute_aggregate(series, AggregateType.COUNT) + # COUNT returns the total number of entries, not just non-zero + assert result == 6.0 + + # To count only non-zero values, filter first then count + non_zero = series[series > 0] + result_filtered = AggregateUtils.compute_aggregate(non_zero, AggregateType.COUNT) + assert result_filtered == 3.0 + + def test_empty_series(self): + """Test that empty series 
returns 0."""
+        series = pd.Series([], dtype=float)
+        for agg_type in [AggregateType.SUM, AggregateType.MEAN, AggregateType.MEDIAN, AggregateType.COUNT]:
+            result = AggregateUtils.compute_aggregate(series, agg_type)
+            assert result == 0.0
+
+
+class TestPovertyRateScenario:
+    """
+    Test the specific poverty rate scenario that was giving a 1% result.
+
+    This tests the complete flow: prepare tables, map variables, apply filters,
+    and compute aggregates with weights.
+    """
+
+    def test_poverty_rate_with_household_filter_person_aggregation(self, weighted_person_household_tables):
+        """
+        Test computing poverty rate at person level with a household-level filter.
+
+        Scenario:
+        - Filter: households in poverty (is_in_poverty == 1)
+        - Variable: count of persons
+        - This should count persons in poor households
+        """
+        # Prepare tables as they would be in production
+        mock_sim = MockSimulation(weighted_person_household_tables)
+        tables = AggregateUtils.prepare_tables(mock_sim)
+
+        # Step 1: Get the household filter variable and map it to person level
+        household_df = tables['household']
+        filter_variable = 'is_in_poverty'
+
+        # Map household filter to person level
+        mapped_filter = AggregateUtils.map_variable_across_entities(
+            household_df,
+            filter_variable,
+            'household',
+            'person',
+            tables
+        )
+
+        # Build filter mask at person level
+        person_table = tables['person']
+        mask = mapped_filter == 1
+
+        # Step 2: Filter the person table
+        filtered_table = person_table[mask]
+
+        # Step 3: Count persons (COUNT is unweighted)
+        # Persons 0 and 1 are in poor households, each with weight 100
+        count = AggregateUtils.compute_aggregate(
+            filtered_table['person_id'],
+            AggregateType.COUNT
+        )
+
+        # COUNT returns the unweighted number of rows, so this is 2 even
+        # though the weighted population would be 200
+        assert count == 2.0
+
+        # For the weighted number of persons in poverty, sum the weights:
+        sum_weights = AggregateUtils.compute_aggregate(
+            filtered_table['person_weight'],
+            AggregateType.SUM
+        )
+        assert sum_weights == 200.0  # 100 + 100
+
+    def test_poverty_rate_household_level(self, weighted_person_household_tables):
+        """
+        Test computing poverty rate at household level.
+
+        This is more straightforward - just filter households and count.
+        """
+        mock_sim = MockSimulation(weighted_person_household_tables)
+        tables = AggregateUtils.prepare_tables(mock_sim)
+
+        household_table = tables['household']
+
+        # Filter to households in poverty
+        mask = household_table['is_in_poverty'] == 1
+        filtered = household_table[mask]
+
+        # Count households
+        count = AggregateUtils.compute_aggregate(
+            filtered['is_in_poverty'],
+            AggregateType.COUNT
+        )
+        assert count == 1.0  # Only 1 household in poverty
+
+        # Weighted sum
+        sum_weights = AggregateUtils.compute_aggregate(
+            filtered['household_weight'],
+            AggregateType.SUM
+        )
+        assert sum_weights == 100.0  # Weight of household 0
+
+    def test_mean_income_in_poor_households(self, weighted_person_household_tables):
+        """
+        Test computing mean income for persons in poor households.
+
+        This tests the complete cross-entity flow with weights.
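+        Expected: persons 0 (income 50000) and 1 (income 0) each carry
+        weight 100, so the weighted mean is 25000.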
+ """ + mock_sim = MockSimulation(weighted_person_household_tables) + tables = AggregateUtils.prepare_tables(mock_sim) + + # Step 1: Map household poverty status to person level + mapped_poverty = AggregateUtils.map_variable_across_entities( + tables['household'], + 'is_in_poverty', + 'household', + 'person', + tables + ) + + # Step 2: Filter persons in poor households + person_table = tables['person'] + mask = mapped_poverty == 1 + filtered_table = person_table[mask] + + # Step 3: Compute mean employment income + mean_income = AggregateUtils.compute_aggregate( + filtered_table['employment_income'], + AggregateType.MEAN + ) + + # Persons 0 (income 50000, weight 100) and 1 (income 0, weight 100) + # Weighted mean: (50000*100 + 0*100) / (100+100) = 25000 + assert mean_income == 25000.0 + + +class TestEdgeCases: + """Test edge cases and error handling.""" + + def test_missing_link_column(self): + """Test error when link column is missing.""" + tables = { + 'person': pd.DataFrame({'age': [30, 40]}), + 'household': pd.DataFrame({'income': [50000, 60000]}) + } + + with pytest.raises(ValueError, match="Link column .* not found"): + AggregateUtils.map_variable_across_entities( + tables['household'], + 'income', + 'household', + 'person', + tables + ) + + def test_unknown_aggregate_function(self): + """Test error with unknown aggregate function.""" + series = pd.Series([10, 20, 30]) + + with pytest.raises(ValueError, match="Unknown aggregate function"): + AggregateUtils.compute_aggregate(series, 'unknown_function') + + def test_map_with_missing_entity(self): + """Test error when entity doesn't exist in tables.""" + tables = { + 'person': pd.DataFrame({'age': [30, 40]}) + } + + with pytest.raises(ValueError, match="No known link"): + AggregateUtils.map_variable_across_entities( + tables['person'], + 'age', + 'person', + 'nonexistent_entity', + tables + ) From db7ebb24bbd8f4d96732af9ffc797bb21103f7f4 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Thu, 23 Oct 2025 13:22:59 +0100 Subject: [PATCH 11/35] Update --- src/policyengine/models/aggregate.py | 464 +++++++++---------- src/policyengine/models/aggregate_change.py | 276 +++-------- tests/test_aggregate.py | 442 ++++++++++++++++++ tests/test_aggregate_change.py | 479 ++++++++++++++++++++ tests/test_aggregate_utils.py | 441 ------------------ 5 files changed, 1206 insertions(+), 896 deletions(-) create mode 100644 tests/test_aggregate.py create mode 100644 tests/test_aggregate_change.py delete mode 100644 tests/test_aggregate_utils.py diff --git a/src/policyengine/models/aggregate.py b/src/policyengine/models/aggregate.py index bf8e110b..166bcfae 100644 --- a/src/policyengine/models/aggregate.py +++ b/src/policyengine/models/aggregate.py @@ -1,10 +1,10 @@ from enum import Enum -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Any, Literal from uuid import uuid4 import pandas as pd from microdf import MicroDataFrame -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field, SkipValidation if TYPE_CHECKING: from policyengine.models import Simulation @@ -17,158 +17,207 @@ class AggregateType(str, Enum): COUNT = "count" -class AggregateUtils: - """Shared utilities for aggregate calculations.""" +class DataEngine: + """Clean data processing engine for aggregations.""" - @staticmethod - def prepare_tables(simulation: "Simulation") -> dict: - """Prepare dataframes from simulation result once.""" - tables = simulation.result - tables = {k: v.copy() for k, v in tables.items()} - - for table in 
tables: - tables[table] = pd.DataFrame(tables[table]) - weight_col = f"{table}_weight" - if weight_col in tables[table].columns: - tables[table] = MicroDataFrame( - tables[table], weights=weight_col - ) - - return tables + def __init__(self, tables: dict): + """Initialize with simulation result tables.""" + self.tables = self._prepare_tables(tables) @staticmethod - def infer_entity(variable_name: str, tables: dict) -> str: - """Infer entity from variable name by checking which table contains it.""" - for entity, table in tables.items(): - if variable_name in table.columns: + def _prepare_tables(tables: dict) -> dict[str, pd.DataFrame]: + """Convert tables to DataFrames with MicroDataFrame for weighted columns.""" + prepared = {} + for name, table in tables.items(): + df = pd.DataFrame(table.copy() if hasattr(table, 'copy') else table) + weight_col = f"{name}_weight" + if weight_col in df.columns: + df = MicroDataFrame(df, weights=weight_col) + prepared[name] = df + return prepared + + def infer_entity(self, variable: str) -> str: + """Infer which entity contains a variable.""" + for entity, table in self.tables.items(): + if variable in table.columns: return entity - raise ValueError(f"Variable {variable_name} not found in any entity table") - - @staticmethod - def get_entity_link_columns() -> dict: - """Return mapping of entity relationships for common PolicyEngine models.""" - return { - # person -> group entity links (copy values down) - "person": { - "benunit": "person_benunit_id", - "household": "person_household_id", - "family": "person_family_id", - "tax_unit": "person_tax_unit_id", - "spm_unit": "person_spm_unit_id", - }, - } + raise ValueError(f"Variable {variable} not found in any entity") - @staticmethod - def map_variable_across_entities( - df: pd.DataFrame, - variable_name: str, - source_entity: str, + def get_variable_series( + self, + variable: str, target_entity: str, - tables: dict + filters: dict[str, Any] | None = None ) -> pd.Series: - """Map a variable from source entity to target entity level.""" - links = AggregateUtils.get_entity_link_columns() + """ + Get variable series at target entity level, with optional filtering. + + Handles cross-entity mapping automatically. 
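+
+        Filter spec keys: 'variable' (required), plus optional 'value',
+        'leq', 'geq', 'quantile_leq' and 'quantile_geq'; see
+        _build_filter_mask.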
+ """ + # Find source entity + source_entity = self.infer_entity(variable) + + # Apply filters first (on target entity) + if filters: + mask = self._build_filter_mask(filters, target_entity) + target_table = self.tables[target_entity][mask] + else: + target_table = self.tables[target_entity] - # Group to person: copy group values to persons using link column - if source_entity != "person" and target_entity == "person": - link_col = links.get("person", {}).get(source_entity) - if link_col is None: - raise ValueError(f"No known link from person to {source_entity}") + # Get variable (map if needed) + if source_entity == target_entity: + return target_table[variable] + else: + # Map across entities + source_series = self.tables[source_entity][variable] + mapped_series = self._map_variable(source_series, source_entity, target_entity) + # Apply filter mask to mapped series + if filters: + return mapped_series[mask] + return mapped_series + + def _build_filter_mask(self, filters: dict[str, Any], target_entity: str) -> pd.Series: + """Build boolean mask from filter specification.""" + target_table = self.tables[target_entity] + mask = pd.Series([True] * len(target_table), index=target_table.index) + + filter_variable = filters.get('variable') + if not filter_variable: + return mask + + # Get filter series (map if cross-entity) + filter_entity = self.infer_entity(filter_variable) + if filter_entity == target_entity: + filter_series = target_table[filter_variable] + else: + filter_series = self._map_variable( + self.tables[filter_entity][filter_variable], + filter_entity, + target_entity + ) - if link_col not in tables["person"].columns: - raise ValueError(f"Link column {link_col} not found in person table") + # Apply value filters + if 'value' in filters and filters['value'] is not None: + mask &= filter_series == filters['value'] - # Create mapping: group position (0-based index) -> value - group_values = df[variable_name].values + if 'leq' in filters and filters['leq'] is not None: + mask &= filter_series <= filters['leq'] - # Map to person level using the link column - person_table = tables["person"] - person_group_ids = person_table[link_col].values + if 'geq' in filters and filters['geq'] is not None: + mask &= filter_series >= filters['geq'] - # Map each person to their group's value - result = pd.Series([group_values[int(gid)] if int(gid) < len(group_values) else 0 - for gid in person_group_ids], index=person_table.index) - return result + # Apply quantile filters + if 'quantile_leq' in filters and filters['quantile_leq'] is not None: + threshold = filter_series.quantile(filters['quantile_leq']) + mask &= filter_series <= threshold - # Person to group: sum persons' values to group level - elif source_entity == "person" and target_entity != "person": - link_col = links.get("person", {}).get(target_entity) - if link_col is None: - raise ValueError(f"No known link from person to {target_entity}") - - if link_col not in df.columns: - raise ValueError(f"Link column {link_col} not found in person table") - - # Sum by group - need to align with group table length - grouped = df.groupby(link_col)[variable_name].sum() - - # Create a series aligned with the group table - group_table = tables[target_entity] - result = pd.Series([grouped.get(i, 0) for i in range(len(group_table))], - index=group_table.index) - return result - - # Group to group: try via person as intermediary - elif source_entity != "person" and target_entity != "person": - # Map source -> person -> target - person_values = 
AggregateUtils.map_variable_across_entities( - df, variable_name, source_entity, "person", tables + if 'quantile_geq' in filters and filters['quantile_geq'] is not None: + threshold = filter_series.quantile(filters['quantile_geq']) + mask &= filter_series >= threshold + + return mask + + def _map_variable( + self, + series: pd.Series, + source_entity: str, + target_entity: str + ) -> pd.Series: + """Map a variable from source to target entity.""" + if source_entity == target_entity: + return series + + # Default entity links (can be overridden) + person_links = { + "benunit": "person_benunit_id", + "household": "person_household_id", + "family": "person_family_id", + "tax_unit": "person_tax_unit_id", + "spm_unit": "person_spm_unit_id", + } + + # Group to person: copy values down + if source_entity != "person" and target_entity == "person": + link_col = person_links.get(source_entity) + if not link_col: + raise ValueError(f"No link from person to {source_entity}") + + person_table = self.tables["person"] + if link_col not in person_table.columns: + raise ValueError(f"Link column {link_col} not in person table") + + group_values = series.values + person_group_ids = person_table[link_col].values + return pd.Series( + [group_values[int(gid)] if int(gid) < len(group_values) else 0 + for gid in person_group_ids], + index=person_table.index ) - # Create temp dataframe with person values - temp_person_df = tables["person"].copy() - temp_person_df[variable_name] = person_values - return AggregateUtils.map_variable_across_entities( - temp_person_df, variable_name, "person", target_entity, tables + # Person to group: aggregate up + elif source_entity == "person" and target_entity != "person": + link_col = person_links.get(target_entity) + if not link_col: + raise ValueError(f"No link from person to {target_entity}") + + person_table = self.tables["person"] + if link_col not in person_table.columns: + raise ValueError(f"Link column {link_col} not in person table") + + grouped = pd.DataFrame({ + link_col: person_table[link_col], + 'value': series + }).groupby(link_col)['value'].sum() + + target_table = self.tables[target_entity] + return pd.Series( + [grouped.get(i, 0) for i in range(len(target_table))], + index=target_table.index ) + # Group to group: via person else: - # Same entity - shouldn't happen but return as-is - return df[variable_name] + person_series = self._map_variable(series, source_entity, "person") + return self._map_variable(person_series, "person", target_entity) @staticmethod - def compute_aggregate( - variable_series: pd.Series | MicroDataFrame, - aggregate_function: str - ) -> float: - """Compute aggregate value from a series.""" - if len(variable_series) == 0: + def aggregate(series: pd.Series, function: AggregateType) -> float: + """Apply aggregation function to series.""" + if len(series) == 0: return 0.0 - # Check if this is a weight column being summed from a MicroDataFrame - # If so, use unweighted sum to avoid weight^2 - is_weight_column = ( - isinstance(variable_series, pd.Series) and - hasattr(variable_series, 'name') and - variable_series.name and - 'weight' in str(variable_series.name).lower() + # Avoid double-weighting weight columns + is_weight = ( + hasattr(series, 'name') and + series.name and + 'weight' in str(series.name).lower() ) - if aggregate_function == AggregateType.SUM: - if is_weight_column: - # Use unweighted sum for weight columns - return float(pd.Series(variable_series.values).sum()) - return float(variable_series.sum()) - elif aggregate_function == 
AggregateType.MEAN: - return float(variable_series.mean()) - elif aggregate_function == AggregateType.MEDIAN: - return float(variable_series.median()) - elif aggregate_function == AggregateType.COUNT: - # For COUNT, return the actual number of entries (not weighted) - # Use len() to count all entries regardless of value - return float(len(variable_series)) + if function == AggregateType.SUM: + if is_weight: + return float(pd.Series(series.values).sum()) + return float(series.sum()) + elif function == AggregateType.MEAN: + return float(series.mean()) + elif function == AggregateType.MEDIAN: + return float(series.median()) + elif function == AggregateType.COUNT: + return float(len(series)) else: - raise ValueError(f"Unknown aggregate function: {aggregate_function}") + raise ValueError(f"Unknown aggregate function: {function}") class Aggregate(BaseModel): + """Aggregate calculation.""" + model_config = ConfigDict(arbitrary_types_allowed=True) + id: str = Field(default_factory=lambda: str(uuid4())) - simulation: "Simulation | None" = None + simulation: SkipValidation["Simulation | None"] = None entity: str | None = None variable_name: str year: int | None = None filter_variable_name: str | None = None - filter_variable_value: str | None = None + filter_variable_value: Any | None = None filter_variable_leq: float | None = None filter_variable_geq: float | None = None filter_variable_quantile_leq: float | None = None @@ -177,154 +226,61 @@ class Aggregate(BaseModel): AggregateType.SUM, AggregateType.MEAN, AggregateType.MEDIAN, AggregateType.COUNT ] reportelement_id: str | None = None - value: float | None = None @staticmethod def run(aggregates: list["Aggregate"]) -> list["Aggregate"]: - """Process aggregates, handling multiple simulations if necessary.""" - # Group aggregates by simulation - simulation_groups = {} + """Process aggregates efficiently by batching those with same simulation.""" + # Group by simulation for batch processing + by_simulation = {} for agg in aggregates: sim_id = id(agg.simulation) if agg.simulation else None - if sim_id not in simulation_groups: - simulation_groups[sim_id] = [] - simulation_groups[sim_id].append(agg) + if sim_id not in by_simulation: + by_simulation[sim_id] = [] + by_simulation[sim_id].append(agg) - # Process each simulation group separately - all_results = [] - for sim_id, sim_aggregates in simulation_groups.items(): + results = [] + for sim_aggregates in by_simulation.values(): if not sim_aggregates: continue - # Get the simulation from the first aggregate in this group simulation = sim_aggregates[0].simulation if simulation is None: - raise ValueError("Aggregate has no simulation attached") - - # Process this simulation's aggregates - group_results = Aggregate._process_simulation_aggregates( - sim_aggregates, simulation - ) - all_results.extend(group_results) - - return all_results - - @staticmethod - def _process_simulation_aggregates( - aggregates: list["Aggregate"], simulation: "Simulation" - ) -> list["Aggregate"]: - """Process aggregates for a single simulation.""" - results = [] - - # Use centralized table preparation - tables = AggregateUtils.prepare_tables(simulation) - - for agg in aggregates: - # Infer entity if not provided - if agg.entity is None: - agg.entity = AggregateUtils.infer_entity(agg.variable_name, tables) - - if agg.entity not in tables: - raise ValueError( - f"Entity {agg.entity} not found in simulation results" - ) - - if agg.year is None: - agg.year = simulation.dataset.year - - # Get the target table - target_table = 
tables[agg.entity] - - # Handle cross-entity filters - mask = None - if agg.filter_variable_name is not None: - # Find which entity contains the filter variable - filter_entity = None - for entity, table in tables.items(): - if agg.filter_variable_name in table.columns: - filter_entity = entity - break - - if filter_entity is None: - raise ValueError( - f"Filter variable {agg.filter_variable_name} not found in any entity" - ) - - # Get the filter series (mapped if needed) - if filter_entity == agg.entity: - filter_series = tables[agg.entity][agg.filter_variable_name] - else: - # Different entity - map filter variable to target entity - filter_df = tables[filter_entity] - filter_series = AggregateUtils.map_variable_across_entities( - filter_df, - agg.filter_variable_name, - filter_entity, - agg.entity, - tables - ) - - # Build mask - mask = pd.Series([True] * len(target_table), index=target_table.index) - - # Apply value/range filters - if agg.filter_variable_value is not None: - mask &= filter_series == agg.filter_variable_value - if agg.filter_variable_leq is not None: - mask &= filter_series <= agg.filter_variable_leq - if agg.filter_variable_geq is not None: - mask &= filter_series >= agg.filter_variable_geq - - # Apply quantile filters - if agg.filter_variable_quantile_leq is not None: - threshold = filter_series.quantile(agg.filter_variable_quantile_leq) - mask &= filter_series <= threshold - if agg.filter_variable_quantile_geq is not None: - threshold = filter_series.quantile(agg.filter_variable_quantile_geq) - mask &= filter_series >= threshold - - # Find which entity contains the variable - variable_entity = None - for entity, table in tables.items(): - if agg.variable_name in table.columns: - variable_entity = entity - break - - if variable_entity is None: - raise ValueError( - f"Variable {agg.variable_name} not found in any entity" - ) - - # Get variable data (mapped if needed) - if variable_entity == agg.entity: - # Same entity - extract from target table - if mask is not None: - # Filter the entire table to preserve MicroDataFrame weights - filtered_table = target_table[mask] - variable_series = filtered_table[agg.variable_name] - else: - variable_series = target_table[agg.variable_name] - else: - # Map variable to target entity - source_table = tables[variable_entity] - variable_series = AggregateUtils.map_variable_across_entities( - source_table, + raise ValueError("Aggregate has no simulation") + + # Create data engine once per simulation (batch optimization) + engine = DataEngine(simulation.result) + + # Process each aggregate + for agg in sim_aggregates: + if agg.year is None: + agg.year = simulation.dataset.year + + # Infer entity if not specified + if agg.entity is None: + agg.entity = engine.infer_entity(agg.variable_name) + + # Build filter specification + filters = None + if agg.filter_variable_name: + filters = { + 'variable': agg.filter_variable_name, + 'value': agg.filter_variable_value, + 'leq': agg.filter_variable_leq, + 'geq': agg.filter_variable_geq, + 'quantile_leq': agg.filter_variable_quantile_leq, + 'quantile_geq': agg.filter_variable_quantile_geq, + } + + # Get variable series with filters + series = engine.get_variable_series( agg.variable_name, - variable_entity, agg.entity, - tables + filters ) - # Apply mask after mapping - if mask is not None: - variable_series = variable_series[mask] - - # Compute aggregate using centralized function - agg.value = AggregateUtils.compute_aggregate( - variable_series, - agg.aggregate_function - ) - results.append(agg) 
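+                # Note: `series` is already mapped to agg.entity and filtered
+                # by the engine at this point, so only the reduction remains.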
+ # Compute aggregate + agg.value = engine.aggregate(series, agg.aggregate_function) + results.append(agg) return results diff --git a/src/policyengine/models/aggregate_change.py b/src/policyengine/models/aggregate_change.py index 49c05266..e869d8fc 100644 --- a/src/policyengine/models/aggregate_change.py +++ b/src/policyengine/models/aggregate_change.py @@ -1,26 +1,26 @@ -from enum import Enum -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Any, Literal from uuid import uuid4 -import pandas as pd -from pydantic import BaseModel, Field -import time +from pydantic import BaseModel, ConfigDict, Field, SkipValidation -from .aggregate import AggregateUtils, AggregateType +from .aggregate import AggregateType, DataEngine if TYPE_CHECKING: from policyengine.models import Simulation class AggregateChange(BaseModel): + """Calculates the change in an aggregate between baseline and comparison simulations.""" + model_config = ConfigDict(arbitrary_types_allowed=True) + id: str = Field(default_factory=lambda: str(uuid4())) - baseline_simulation: "Simulation | None" = None - comparison_simulation: "Simulation | None" = None + baseline_simulation: SkipValidation["Simulation | None"] = None + comparison_simulation: SkipValidation["Simulation | None"] = None entity: str | None = None variable_name: str year: int | None = None filter_variable_name: str | None = None - filter_variable_value: str | None = None + filter_variable_value: Any | None = None filter_variable_leq: float | None = None filter_variable_geq: float | None = None filter_variable_quantile_leq: float | None = None @@ -35,219 +35,93 @@ class AggregateChange(BaseModel): change: float | None = None relative_change: float | None = None - @staticmethod def run(aggregate_changes: list["AggregateChange"]) -> list["AggregateChange"]: - """Process aggregate changes, batching those with the same simulation pair.""" - start_time = time.time() - print(f"[PERFORMANCE] AggregateChange.run starting with {len(aggregate_changes)} items") - - # Group aggregate changes by simulation pair for batch processing - from collections import defaultdict - grouped = defaultdict(list) + """Process aggregate changes efficiently by batching those with same simulation pair.""" + # Group by simulation pair for batch processing + by_sim_pair = {} for agg_change in aggregate_changes: if agg_change.baseline_simulation is None: - raise ValueError("AggregateChange has no baseline simulation attached") + raise ValueError("AggregateChange missing baseline_simulation") if agg_change.comparison_simulation is None: - raise ValueError("AggregateChange has no comparison simulation attached") + raise ValueError("AggregateChange missing comparison_simulation") - key = (agg_change.baseline_simulation.id, agg_change.comparison_simulation.id) - grouped[key].append(agg_change) - - print(f"[PERFORMANCE] Grouped {len(aggregate_changes)} items into {len(grouped)} simulation pairs") + key = ( + id(agg_change.baseline_simulation), + id(agg_change.comparison_simulation) + ) + if key not in by_sim_pair: + by_sim_pair[key] = [] + by_sim_pair[key].append(agg_change) results = [] + for pair_aggregates in by_sim_pair.values(): + if not pair_aggregates: + continue - for (baseline_id, comparison_id), group in grouped.items(): - group_start = time.time() - print(f"[PERFORMANCE] Processing batch of {len(group)} items for sim pair {baseline_id[:8]}...{comparison_id[:8]}") - - # Get simulation objects once for the group - baseline_sim = group[0].baseline_simulation - 
comparison_sim = group[0].comparison_simulation + # Get simulation objects + baseline_sim = pair_aggregates[0].baseline_simulation + comparison_sim = pair_aggregates[0].comparison_simulation - # Pre-compute simulation dataframes once per batch - baseline_tables = AggregateUtils.prepare_tables(baseline_sim) - comparison_tables = AggregateUtils.prepare_tables(comparison_sim) + # Create data engines once per simulation pair (batch optimization) + baseline_engine = DataEngine(baseline_sim.result) + comparison_engine = DataEngine(comparison_sim.result) - prep_time = time.time() - print(f"[PERFORMANCE] Table preparation took {prep_time - group_start:.3f} seconds") + # Process each aggregate change + for agg_change in pair_aggregates: + if agg_change.year is None: + agg_change.year = baseline_sim.dataset.year - # Process each item in the group - for idx, agg_change in enumerate(group): - item_start = time.time() - - # Infer entity if not provided + # Infer entity if not specified if agg_change.entity is None: - agg_change.entity = AggregateUtils.infer_entity( - agg_change.variable_name, - baseline_tables - ) - - # Compute filter mask on baseline - filter_mask = AggregateChange._get_filter_mask_from_tables( - agg_change, baseline_tables + agg_change.entity = baseline_engine.infer_entity(agg_change.variable_name) + + # Build filter specification + filters = None + if agg_change.filter_variable_name: + filters = { + 'variable': agg_change.filter_variable_name, + 'value': agg_change.filter_variable_value, + 'leq': agg_change.filter_variable_leq, + 'geq': agg_change.filter_variable_geq, + 'quantile_leq': agg_change.filter_variable_quantile_leq, + 'quantile_geq': agg_change.filter_variable_quantile_geq, + } + + # Get variable series with filters for both simulations + baseline_series = baseline_engine.get_variable_series( + agg_change.variable_name, + agg_change.entity, + filters ) - - # Compute baseline value - baseline_value = AggregateChange._compute_single_aggregate_from_tables( - agg_change, baseline_tables, filter_mask + comparison_series = comparison_engine.get_variable_series( + agg_change.variable_name, + agg_change.entity, + filters ) - # Compute comparison value using same filter - comparison_value = AggregateChange._compute_single_aggregate_from_tables( - agg_change, comparison_tables, filter_mask + # Compute aggregates + agg_change.baseline_value = baseline_engine.aggregate( + baseline_series, + agg_change.aggregate_function + ) + agg_change.comparison_value = comparison_engine.aggregate( + comparison_series, + agg_change.aggregate_function ) - # Compute changes - agg_change.baseline_value = baseline_value - agg_change.comparison_value = comparison_value - agg_change.change = comparison_value - baseline_value + # Calculate changes + agg_change.change = agg_change.comparison_value - agg_change.baseline_value - # Compute relative change (avoiding division by zero) - if baseline_value != 0: - agg_change.relative_change = (comparison_value - baseline_value) / abs(baseline_value) + if agg_change.baseline_value != 0: + agg_change.relative_change = ( + agg_change.change / abs(agg_change.baseline_value) + ) else: - agg_change.relative_change = None if comparison_value == 0 else float('inf') + agg_change.relative_change = ( + None if agg_change.comparison_value == 0 else float('inf') + ) results.append(agg_change) - group_time = time.time() - print(f"[PERFORMANCE] Batch processing took {group_time - group_start:.3f} seconds ({(group_time - group_start) / len(group):.3f}s per item)") - - total_time 
= time.time() - print(f"[PERFORMANCE] AggregateChange.run completed in {total_time - start_time:.2f} seconds") return results - - @staticmethod - def _get_filter_mask_from_tables( - agg_change: "AggregateChange", tables: dict - ) -> pd.Series | None: - """Get filter mask from pre-prepared tables, handling cross-entity filters.""" - if agg_change.filter_variable_name is None: - return None - - if agg_change.entity not in tables: - raise ValueError( - f"Entity {agg_change.entity} not found in simulation results" - ) - - # Find which entity contains the filter variable - filter_entity = None - for entity, table in tables.items(): - if agg_change.filter_variable_name in table.columns: - filter_entity = entity - break - - if filter_entity is None: - raise ValueError( - f"Filter variable {agg_change.filter_variable_name} not found in any entity" - ) - - # Get the dataframe for filtering - if filter_entity == agg_change.entity: - # Same entity - use directly - df = tables[agg_change.entity] - filter_series = df[agg_change.filter_variable_name] - else: - # Different entity - need to map filter variable to target entity - filter_df = tables[filter_entity] - mapped_filter = AggregateUtils.map_variable_across_entities( - filter_df, - agg_change.filter_variable_name, - filter_entity, - agg_change.entity, - tables - ) - df = tables[agg_change.entity] - filter_series = mapped_filter - - mask = pd.Series([True] * len(df), index=df.index) - - if agg_change.filter_variable_value is not None: - mask &= filter_series == agg_change.filter_variable_value - - if agg_change.filter_variable_leq is not None: - mask &= filter_series <= agg_change.filter_variable_leq - - if agg_change.filter_variable_geq is not None: - mask &= filter_series >= agg_change.filter_variable_geq - - if agg_change.filter_variable_quantile_leq is not None or agg_change.filter_variable_quantile_geq is not None: - if agg_change.filter_variable_quantile_leq is not None: - threshold = filter_series.quantile(agg_change.filter_variable_quantile_leq) - mask &= filter_series <= threshold - - if agg_change.filter_variable_quantile_geq is not None: - threshold = filter_series.quantile(agg_change.filter_variable_quantile_geq) - mask &= filter_series >= threshold - - return mask - - @staticmethod - def _compute_single_aggregate_from_tables( - agg_change: "AggregateChange", - tables: dict, - filter_mask: pd.Series | None = None - ) -> float: - """Compute aggregate value from pre-prepared tables.""" - if agg_change.entity not in tables: - raise ValueError( - f"Entity {agg_change.entity} not found in simulation results" - ) - - # Check if variable is in the target entity - target_entity = agg_change.entity - variable_entity = None - - # Find which entity contains the variable - for entity, table in tables.items(): - if agg_change.variable_name in table.columns: - variable_entity = entity - break - - if variable_entity is None: - raise ValueError( - f"Variable {agg_change.variable_name} not found in any entity" - ) - - # If variable is in a different entity than the filter, we need to map - if variable_entity != target_entity: - # Get the variable data from its native entity - source_table = tables[variable_entity] - - # Map it to the target entity level - try: - mapped_series = AggregateUtils.map_variable_across_entities( - source_table, - agg_change.variable_name, - variable_entity, - target_entity, - tables - ) - # Create a temporary dataframe with the mapped variable - table = tables[target_entity].copy() - table[agg_change.variable_name] = 
mapped_series - except ValueError as e: - # If mapping fails, raise informative error - raise ValueError( - f"Variable {agg_change.variable_name} is in {variable_entity} entity, " - f"but filters are at {target_entity} level. Cannot map between these entities: {str(e)}" - ) - else: - table = tables[agg_change.entity] - - df = table - - if filter_mask is not None: - df = df[filter_mask] - - if len(df) == 0: - return 0.0 - - # Use centralized compute function - return AggregateUtils.compute_aggregate( - df[agg_change.variable_name], - agg_change.aggregate_function - ) - diff --git a/tests/test_aggregate.py b/tests/test_aggregate.py new file mode 100644 index 00000000..cc38040a --- /dev/null +++ b/tests/test_aggregate.py @@ -0,0 +1,442 @@ +""" +Tests for the clean aggregate implementation. + +Tests cover: +- Basic aggregations (sum, mean, median, count) +- Filtering (value, range, quantile) +- Cross-entity queries +- Batching efficiency +- Edge cases +""" + +import pytest +import pandas as pd + +from policyengine.models.aggregate import Aggregate, AggregateType + + +class MockSimulation: + """Mock simulation for testing.""" + + def __init__(self, result, year=2024): + self.result = result + self.dataset = MockDataset(year) + + +class MockDataset: + def __init__(self, year): + self.year = year + + +@pytest.fixture +def sample_tables(): + """Create sample person/household tables for testing.""" + person = pd.DataFrame({ + 'person_id': [0, 1, 2, 3], + 'person_household_id': [0, 0, 1, 1], + 'person_weight': [100.0, 100.0, 200.0, 200.0], + 'age': [30, 5, 45, 40], + 'employment_income': [50000, 0, 60000, 55000], + }) + + household = pd.DataFrame({ + 'household_id': [0, 1], + 'household_weight': [100.0, 200.0], + 'household_net_income': [50000, 115000], + 'is_in_poverty': [1, 0], + }) + + return {'person': person, 'household': household} + + +class TestBasicAggregations: + """Test basic aggregation functions.""" + + def test_sum(self, sample_tables): + """Test sum aggregation.""" + sim = MockSimulation(sample_tables) + agg = Aggregate( + simulation=sim, + variable_name='employment_income', + aggregate_function=AggregateType.SUM, + entity='person' + ) + results = Aggregate.run([agg]) + # Weighted sum: 50000*100 + 0*100 + 60000*200 + 55000*200 = 28,000,000 + assert results[0].value == 28_000_000.0 + + def test_mean(self, sample_tables): + """Test mean aggregation.""" + sim = MockSimulation(sample_tables) + agg = Aggregate( + simulation=sim, + variable_name='age', + aggregate_function=AggregateType.MEAN, + entity='person' + ) + results = Aggregate.run([agg]) + # Weighted mean: (30*100 + 5*100 + 45*200 + 40*200) / 600 = 34.17 + assert round(results[0].value, 2) == 34.17 + + def test_count(self, sample_tables): + """Test count aggregation.""" + sim = MockSimulation(sample_tables) + agg = Aggregate( + simulation=sim, + variable_name='person_id', + aggregate_function=AggregateType.COUNT, + entity='person' + ) + results = Aggregate.run([agg]) + assert results[0].value == 4.0 + + def test_median(self, sample_tables): + """Test median aggregation.""" + sim = MockSimulation(sample_tables) + agg = Aggregate( + simulation=sim, + variable_name='age', + aggregate_function=AggregateType.MEDIAN, + entity='person' + ) + results = Aggregate.run([agg]) + assert results[0].value > 0 + + def test_entity_inference(self, sample_tables): + """Test that entity is inferred correctly.""" + sim = MockSimulation(sample_tables) + agg = Aggregate( + simulation=sim, + variable_name='employment_income', + 
aggregate_function=AggregateType.SUM + # entity not specified + ) + results = Aggregate.run([agg]) + assert results[0].entity == 'person' + assert results[0].value == 28_000_000.0 + + +class TestFiltering: + """Test filtering functionality.""" + + def test_value_filter(self, sample_tables): + """Test filtering with exact value match.""" + sim = MockSimulation(sample_tables) + agg = Aggregate( + simulation=sim, + variable_name='person_id', + aggregate_function=AggregateType.COUNT, + entity='person', + filter_variable_name='age', + filter_variable_value=30 + ) + results = Aggregate.run([agg]) + assert results[0].value == 1.0 + + def test_range_filter_leq(self, sample_tables): + """Test filtering with <= operator.""" + sim = MockSimulation(sample_tables) + agg = Aggregate( + simulation=sim, + variable_name='employment_income', + aggregate_function=AggregateType.SUM, + entity='person', + filter_variable_name='age', + filter_variable_leq=35 + ) + results = Aggregate.run([agg]) + # Persons with age <= 35: person 0 (age 30) and person 1 (age 5) + # Weighted sum: 50000*100 + 0*100 = 5,000,000 + assert results[0].value == 5_000_000.0 + + def test_range_filter_geq(self, sample_tables): + """Test filtering with >= operator.""" + sim = MockSimulation(sample_tables) + agg = Aggregate( + simulation=sim, + variable_name='employment_income', + aggregate_function=AggregateType.SUM, + entity='person', + filter_variable_name='age', + filter_variable_geq=40 + ) + results = Aggregate.run([agg]) + # Persons with age >= 40: person 2 (age 45) and person 3 (age 40) + # Weighted sum: 60000*200 + 55000*200 = 23,000,000 + assert results[0].value == 23_000_000.0 + + def test_combined_range_filters(self, sample_tables): + """Test combining leq and geq filters.""" + sim = MockSimulation(sample_tables) + agg = Aggregate( + simulation=sim, + variable_name='employment_income', + aggregate_function=AggregateType.SUM, + entity='person', + filter_variable_name='age', + filter_variable_geq=18, + filter_variable_leq=35 + ) + results = Aggregate.run([agg]) + # Person 0: age 30, income 50000, weight 100 + assert results[0].value == 5_000_000.0 + + def test_quantile_filter_leq(self, sample_tables): + """Test filtering with quantile_leq.""" + sim = MockSimulation(sample_tables) + agg = Aggregate( + simulation=sim, + variable_name='person_id', + aggregate_function=AggregateType.COUNT, + entity='person', + filter_variable_name='age', + filter_variable_quantile_leq=0.5 + ) + results = Aggregate.run([agg]) + # Bottom 50% by age should have at least 2 people + assert results[0].value >= 2.0 + + def test_quantile_filter_geq(self, sample_tables): + """Test filtering with quantile_geq.""" + sim = MockSimulation(sample_tables) + agg = Aggregate( + simulation=sim, + variable_name='person_id', + aggregate_function=AggregateType.COUNT, + entity='person', + filter_variable_name='age', + filter_variable_quantile_geq=0.5 + ) + results = Aggregate.run([agg]) + # Top 50% by age should have at least 2 people + assert results[0].value >= 2.0 + + +class TestCrossEntity: + """Test cross-entity queries.""" + + def test_household_filter_on_person_aggregation(self, sample_tables): + """Test filtering persons by household variable.""" + sim = MockSimulation(sample_tables) + agg = Aggregate( + simulation=sim, + variable_name='person_id', + aggregate_function=AggregateType.COUNT, + entity='person', + filter_variable_name='is_in_poverty', + filter_variable_value=1 + ) + results = Aggregate.run([agg]) + # Persons in poor households (household 0): persons 0 and 
1 + assert results[0].value == 2.0 + + def test_person_to_household_aggregation(self, sample_tables): + """Test aggregating person variable at household level.""" + sim = MockSimulation(sample_tables) + agg = Aggregate( + simulation=sim, + variable_name='employment_income', + aggregate_function=AggregateType.SUM, + entity='household' + ) + results = Aggregate.run([agg]) + # Employment income summed to household level: 50000 + 115000 = 165,000 + assert results[0].value == 165_000.0 + + def test_poverty_rate_calculation(self, sample_tables): + """Test calculating poverty rate.""" + sim = MockSimulation(sample_tables) + + # Count persons in poverty + poor = Aggregate( + simulation=sim, + variable_name='person_id', + aggregate_function=AggregateType.COUNT, + entity='person', + filter_variable_name='is_in_poverty', + filter_variable_value=1 + ) + + # Total persons + total = Aggregate( + simulation=sim, + variable_name='person_id', + aggregate_function=AggregateType.COUNT, + entity='person' + ) + + results = Aggregate.run([poor, total]) + poverty_rate = results[0].value / results[1].value + assert poverty_rate == 0.5 # 2 out of 4 persons + + def test_mean_income_for_poor(self, sample_tables): + """Test mean income for persons in poor households.""" + sim = MockSimulation(sample_tables) + agg = Aggregate( + simulation=sim, + variable_name='employment_income', + aggregate_function=AggregateType.MEAN, + entity='person', + filter_variable_name='is_in_poverty', + filter_variable_value=1 + ) + results = Aggregate.run([agg]) + # Persons in poverty: person 0 (income 50000, weight 100), person 1 (income 0, weight 100) + # Weighted mean: (50000*100 + 0*100) / 200 = 25000 + assert results[0].value == 25000.0 + + +class TestBatching: + """Test batch processing efficiency.""" + + def test_batch_same_simulation(self, sample_tables): + """Test that aggregates with same simulation are batched.""" + sim = MockSimulation(sample_tables) + + aggregates = [ + Aggregate( + simulation=sim, + variable_name='employment_income', + aggregate_function=AggregateType.SUM, + entity='person' + ), + Aggregate( + simulation=sim, + variable_name='age', + aggregate_function=AggregateType.MEAN, + entity='person' + ), + Aggregate( + simulation=sim, + variable_name='person_id', + aggregate_function=AggregateType.COUNT, + entity='person' + ), + ] + + results = Aggregate.run(aggregates) + assert len(results) == 3 + assert results[0].value == 28_000_000.0 + assert round(results[1].value, 2) == 34.17 + assert results[2].value == 4.0 + + def test_batch_different_filters(self, sample_tables): + """Test batching aggregates with different filters.""" + sim = MockSimulation(sample_tables) + + aggregates = [ + Aggregate( + simulation=sim, + variable_name='person_id', + aggregate_function=AggregateType.COUNT, + entity='person', + filter_variable_name='age', + filter_variable_leq=17 + ), + Aggregate( + simulation=sim, + variable_name='person_id', + aggregate_function=AggregateType.COUNT, + entity='person', + filter_variable_name='age', + filter_variable_geq=18 + ), + ] + + results = Aggregate.run(aggregates) + assert len(results) == 2 + assert results[0].value == 1.0 # Children + assert results[1].value == 3.0 # Adults + + +class TestEdgeCases: + """Test edge cases.""" + + def test_empty_result(self, sample_tables): + """Test filtering that results in empty set.""" + sim = MockSimulation(sample_tables) + agg = Aggregate( + simulation=sim, + variable_name='person_id', + aggregate_function=AggregateType.COUNT, + entity='person', + 
filter_variable_name='age', + filter_variable_value=999 + ) + results = Aggregate.run([agg]) + assert results[0].value == 0.0 + + def test_weight_column_sum(self, sample_tables): + """Test that weight columns avoid double-weighting.""" + sim = MockSimulation(sample_tables) + agg = Aggregate( + simulation=sim, + variable_name='person_weight', + aggregate_function=AggregateType.SUM, + entity='person' + ) + results = Aggregate.run([agg]) + # Simple sum (not weighted): 100 + 100 + 200 + 200 = 600 + assert results[0].value == 600.0 + + def test_missing_variable(self, sample_tables): + """Test error when variable doesn't exist.""" + sim = MockSimulation(sample_tables) + agg = Aggregate( + simulation=sim, + variable_name='nonexistent', + aggregate_function=AggregateType.SUM + ) + with pytest.raises(ValueError, match='not found'): + Aggregate.run([agg]) + + +class TestComplexScenarios: + """Test complex real-world scenarios.""" + + def test_poverty_by_age_group(self, sample_tables): + """Test poverty analysis by age group.""" + sim = MockSimulation(sample_tables) + + # Children in poverty + children_poor = Aggregate( + simulation=sim, + variable_name='person_id', + aggregate_function=AggregateType.COUNT, + entity='person', + filter_variable_name='age', + filter_variable_leq=17 + ) + + results = Aggregate.run([children_poor]) + assert results[0].value == 1.0 # Person 1 (age 5) + + def test_multiple_aggregations(self, sample_tables): + """Test running multiple different aggregations together.""" + sim = MockSimulation(sample_tables) + + aggs = [ + Aggregate( + simulation=sim, + variable_name='employment_income', + aggregate_function=AggregateType.SUM, + entity='person' + ), + Aggregate( + simulation=sim, + variable_name='employment_income', + aggregate_function=AggregateType.MEAN, + entity='person' + ), + Aggregate( + simulation=sim, + variable_name='employment_income', + aggregate_function=AggregateType.MEDIAN, + entity='person' + ), + ] + + results = Aggregate.run(aggs) + assert len(results) == 3 + assert results[0].value > results[1].value > 0 + assert results[2].value > 0 diff --git a/tests/test_aggregate_change.py b/tests/test_aggregate_change.py new file mode 100644 index 00000000..f0f529fa --- /dev/null +++ b/tests/test_aggregate_change.py @@ -0,0 +1,479 @@ +""" +Tests for the clean AggregateChange implementation. 
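+
+These mirror the Aggregate tests but run each metric against a
+baseline/comparison simulation pair. A minimal sketch of the pattern under
+test (baseline_df/comparison_df are hypothetical stand-ins for the
+fixtures defined below):
+
+    baseline_sim = MockSimulation({'person': baseline_df})
+    comparison_sim = MockSimulation({'person': comparison_df})
+    change = AggregateChange(
+        baseline_simulation=baseline_sim,
+        comparison_simulation=comparison_sim,
+        variable_name='benefits',
+        aggregate_function=AggregateType.SUM,
+    )
+    result = AggregateChange.run([change])[0]
+    # result.change == result.comparison_value - result.baseline_value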
+ +Tests cover: +- Basic change calculations +- Relative change calculations +- Cross-entity filters +- Batching multiple changes +- Edge cases +""" + +import pytest +import pandas as pd + +from policyengine.models.aggregate_change import AggregateChange +from policyengine.models.aggregate import AggregateType + + +class MockSimulation: + """Mock simulation for testing.""" + + def __init__(self, result, year=2024, sim_id=None): + self.result = result + self.dataset = MockDataset(year) + self.id = sim_id or "sim_123" + + +class MockDataset: + def __init__(self, year): + self.year = year + + +@pytest.fixture +def baseline_tables(): + """Baseline simulation tables.""" + person = pd.DataFrame({ + 'person_id': [0, 1, 2, 3], + 'person_household_id': [0, 0, 1, 1], + 'person_weight': [100.0, 100.0, 200.0, 200.0], + 'age': [30, 5, 45, 40], + 'employment_income': [50000, 0, 60000, 55000], + 'benefits': [5000, 2000, 0, 0], + }) + + household = pd.DataFrame({ + 'household_id': [0, 1], + 'household_weight': [100.0, 200.0], + 'household_net_income': [57000, 115000], + 'is_in_poverty': [1, 0], + }) + + return {'person': person, 'household': household} + + +@pytest.fixture +def comparison_tables(): + """Comparison simulation tables (with policy change).""" + person = pd.DataFrame({ + 'person_id': [0, 1, 2, 3], + 'person_household_id': [0, 0, 1, 1], + 'person_weight': [100.0, 100.0, 200.0, 200.0], + 'age': [30, 5, 45, 40], + 'employment_income': [50000, 0, 60000, 55000], + 'benefits': [8000, 3000, 1000, 1000], # Benefits increased + }) + + household = pd.DataFrame({ + 'household_id': [0, 1], + 'household_weight': [100.0, 200.0], + 'household_net_income': [61000, 117000], # Incomes increased + 'is_in_poverty': [0, 0], # Household 0 lifted out of poverty + }) + + return {'person': person, 'household': household} + + +class TestBasicChanges: + """Test basic change calculations.""" + + def test_simple_change(self, baseline_tables, comparison_tables): + """Test calculating a simple change in totals.""" + baseline_sim = MockSimulation(baseline_tables) + comparison_sim = MockSimulation(comparison_tables) + + agg_change = AggregateChange( + baseline_simulation=baseline_sim, + comparison_simulation=comparison_sim, + variable_name='benefits', + aggregate_function=AggregateType.SUM, + entity='person' + ) + + results = AggregateChange.run([agg_change]) + result = results[0] + + # Baseline weighted sum: 5000*100 + 2000*100 + 0*200 + 0*200 = 700,000 + assert result.baseline_value == 700_000.0 + + # Comparison weighted sum: 8000*100 + 3000*100 + 1000*200 + 1000*200 = 1,500,000 + assert result.comparison_value == 1_500_000.0 + + # Change: 1,500,000 - 700,000 = 800,000 + assert result.change == 800_000.0 + + # Relative change: 800,000 / 700,000 ≈ 1.14 + assert round(result.relative_change, 2) == 1.14 + + def test_mean_change(self, baseline_tables, comparison_tables): + """Test calculating change in mean values.""" + baseline_sim = MockSimulation(baseline_tables) + comparison_sim = MockSimulation(comparison_tables) + + agg_change = AggregateChange( + baseline_simulation=baseline_sim, + comparison_simulation=comparison_sim, + variable_name='benefits', + aggregate_function=AggregateType.MEAN, + entity='person' + ) + + results = AggregateChange.run([agg_change]) + result = results[0] + + # Baseline weighted mean: 700,000 / 600 = 1,166.67 + assert round(result.baseline_value, 2) == 1166.67 + + # Comparison weighted mean: 1,500,000 / 600 = 2,500 + assert result.comparison_value == 2500.0 + + # Change: 2,500 - 1,166.67 = 1,333.33 
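+        # Relative change for the mean equals the ratio for the sum, since
+        # the weights are unchanged: 1,333.33 / 1,166.67 ≈ 1.14 (not asserted here).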
+ assert round(result.change, 2) == 1333.33 + + def test_count_change(self, baseline_tables, comparison_tables): + """Test change in counts (e.g., poverty count).""" + baseline_sim = MockSimulation(baseline_tables) + comparison_sim = MockSimulation(comparison_tables) + + # Count households in poverty + agg_change = AggregateChange( + baseline_simulation=baseline_sim, + comparison_simulation=comparison_sim, + variable_name='household_id', + aggregate_function=AggregateType.COUNT, + entity='household', + filter_variable_name='is_in_poverty', + filter_variable_value=1 + ) + + results = AggregateChange.run([agg_change]) + result = results[0] + + # Baseline: 1 household in poverty + assert result.baseline_value == 1.0 + + # Comparison: 0 households in poverty + assert result.comparison_value == 0.0 + + # Change: -1 household + assert result.change == -1.0 + + +class TestCrossEntityChanges: + """Test changes with cross-entity filters.""" + + def test_persons_in_poverty_change(self, baseline_tables, comparison_tables): + """Test change in count of persons in poor households.""" + baseline_sim = MockSimulation(baseline_tables) + comparison_sim = MockSimulation(comparison_tables) + + # Count persons in poor households + agg_change = AggregateChange( + baseline_simulation=baseline_sim, + comparison_simulation=comparison_sim, + variable_name='person_id', + aggregate_function=AggregateType.COUNT, + entity='person', + filter_variable_name='is_in_poverty', + filter_variable_value=1 + ) + + results = AggregateChange.run([agg_change]) + result = results[0] + + # Baseline: 2 persons in poor households (persons 0, 1) + assert result.baseline_value == 2.0 + + # Comparison: 0 persons in poor households + assert result.comparison_value == 0.0 + + # Change: -2 persons + assert result.change == -2.0 + + def test_mean_benefits_for_poor(self, baseline_tables, comparison_tables): + """Test change in mean benefits for persons in poor households.""" + baseline_sim = MockSimulation(baseline_tables) + comparison_sim = MockSimulation(comparison_tables) + + agg_change = AggregateChange( + baseline_simulation=baseline_sim, + comparison_simulation=comparison_sim, + variable_name='benefits', + aggregate_function=AggregateType.MEAN, + entity='person', + filter_variable_name='is_in_poverty', + filter_variable_value=1 + ) + + results = AggregateChange.run([agg_change]) + result = results[0] + + # Baseline: persons 0 and 1 in poverty + # Weighted mean: (5000*100 + 2000*100) / 200 = 3,500 + assert result.baseline_value == 3500.0 + + # Comparison: 0 persons in poverty (empty filter) + assert result.comparison_value == 0.0 + + +class TestBatching: + """Test efficient batching of multiple changes.""" + + def test_batch_multiple_changes(self, baseline_tables, comparison_tables): + """Test processing multiple aggregate changes efficiently.""" + baseline_sim = MockSimulation(baseline_tables) + comparison_sim = MockSimulation(comparison_tables) + + changes = [ + AggregateChange( + baseline_simulation=baseline_sim, + comparison_simulation=comparison_sim, + variable_name='benefits', + aggregate_function=AggregateType.SUM, + entity='person' + ), + AggregateChange( + baseline_simulation=baseline_sim, + comparison_simulation=comparison_sim, + variable_name='employment_income', + aggregate_function=AggregateType.MEAN, + entity='person' + ), + AggregateChange( + baseline_simulation=baseline_sim, + comparison_simulation=comparison_sim, + variable_name='person_id', + aggregate_function=AggregateType.COUNT, + entity='person', + 
filter_variable_name='is_in_poverty', + filter_variable_value=1 + ), + ] + + results = AggregateChange.run(changes) + + assert len(results) == 3 + assert results[0].change == 800_000.0 # Benefits increased + assert results[1].change == 0.0 # Employment income unchanged + assert results[2].change == -2.0 # Poverty count decreased + + +class TestRangeFilters: + """Test aggregate changes with range filters.""" + + def test_change_with_age_filter(self, baseline_tables, comparison_tables): + """Test change in benefits for specific age group.""" + baseline_sim = MockSimulation(baseline_tables) + comparison_sim = MockSimulation(comparison_tables) + + # Benefits for children (age < 18) + agg_change = AggregateChange( + baseline_simulation=baseline_sim, + comparison_simulation=comparison_sim, + variable_name='benefits', + aggregate_function=AggregateType.SUM, + entity='person', + filter_variable_name='age', + filter_variable_leq=17 + ) + + results = AggregateChange.run([agg_change]) + result = results[0] + + # Person 1 (age 5): baseline 2000*100, comparison 3000*100 + assert result.baseline_value == 200_000.0 + assert result.comparison_value == 300_000.0 + assert result.change == 100_000.0 + + def test_change_with_quantile_filter(self, baseline_tables, comparison_tables): + """Test change for income quantiles.""" + baseline_sim = MockSimulation(baseline_tables) + comparison_sim = MockSimulation(comparison_tables) + + # Benefits for bottom 50% by income + agg_change = AggregateChange( + baseline_simulation=baseline_sim, + comparison_simulation=comparison_sim, + variable_name='benefits', + aggregate_function=AggregateType.MEAN, + entity='person', + filter_variable_name='employment_income', + filter_variable_quantile_leq=0.5 + ) + + results = AggregateChange.run([agg_change]) + result = results[0] + + # Should get results for lower-income persons + assert result.baseline_value >= 0 + assert result.comparison_value >= 0 + + +class TestEdgeCases: + """Test edge cases.""" + + def test_zero_baseline_value(self, baseline_tables, comparison_tables): + """Test relative change when baseline is zero.""" + # Create tables where baseline has zero value + baseline_zero = { + 'person': pd.DataFrame({ + 'person_id': [0, 1], + 'person_weight': [1.0, 1.0], + 'new_benefit': [0, 0] + }) + } + + comparison_nonzero = { + 'person': pd.DataFrame({ + 'person_id': [0, 1], + 'person_weight': [1.0, 1.0], + 'new_benefit': [1000, 1000] + }) + } + + baseline_sim = MockSimulation(baseline_zero) + comparison_sim = MockSimulation(comparison_nonzero) + + agg_change = AggregateChange( + baseline_simulation=baseline_sim, + comparison_simulation=comparison_sim, + variable_name='new_benefit', + aggregate_function=AggregateType.SUM, + entity='person' + ) + + results = AggregateChange.run([agg_change]) + result = results[0] + + assert result.baseline_value == 0.0 + assert result.comparison_value == 2000.0 + assert result.change == 2000.0 + assert result.relative_change == float('inf') + + def test_both_zero(self): + """Test when both baseline and comparison are zero.""" + baseline = { + 'person': pd.DataFrame({ + 'person_id': [0, 1], + 'person_weight': [1.0, 1.0], + 'value': [0, 0] + }) + } + + baseline_sim = MockSimulation(baseline) + comparison_sim = MockSimulation(baseline) + + agg_change = AggregateChange( + baseline_simulation=baseline_sim, + comparison_simulation=comparison_sim, + variable_name='value', + aggregate_function=AggregateType.SUM, + entity='person' + ) + + results = AggregateChange.run([agg_change]) + result = results[0] 
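+        # Convention under test: a zero baseline with a nonzero comparison
+        # yields relative_change == +inf; zero/zero yields None (see test_both_zero).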
+ + assert result.baseline_value == 0.0 + assert result.comparison_value == 0.0 + assert result.change == 0.0 + assert result.relative_change is None + + def test_missing_simulation(self): + """Test error when simulation is missing.""" + agg_change = AggregateChange( + variable_name='value', + aggregate_function=AggregateType.SUM + ) + + with pytest.raises(ValueError, match='missing baseline_simulation'): + AggregateChange.run([agg_change]) + + def test_negative_change(self, baseline_tables, comparison_tables): + """Test calculating negative changes correctly.""" + # Create scenario where value decreases + baseline_high = { + 'person': pd.DataFrame({ + 'person_id': [0, 1], + 'person_weight': [1.0, 1.0], + 'value': [1000, 1000] + }) + } + + comparison_low = { + 'person': pd.DataFrame({ + 'person_id': [0, 1], + 'person_weight': [1.0, 1.0], + 'value': [500, 500] + }) + } + + baseline_sim = MockSimulation(baseline_high) + comparison_sim = MockSimulation(comparison_low) + + agg_change = AggregateChange( + baseline_simulation=baseline_sim, + comparison_simulation=comparison_sim, + variable_name='value', + aggregate_function=AggregateType.SUM, + entity='person' + ) + + results = AggregateChange.run([agg_change]) + result = results[0] + + assert result.baseline_value == 2000.0 + assert result.comparison_value == 1000.0 + assert result.change == -1000.0 + assert result.relative_change == -0.5 + + +class TestRealWorldScenarios: + """Test realistic policy analysis scenarios.""" + + def test_poverty_impact_analysis(self, baseline_tables, comparison_tables): + """Test complete poverty impact analysis.""" + baseline_sim = MockSimulation(baseline_tables) + comparison_sim = MockSimulation(comparison_tables) + + analysis = [ + # Total poverty count + AggregateChange( + baseline_simulation=baseline_sim, + comparison_simulation=comparison_sim, + variable_name='household_id', + aggregate_function=AggregateType.COUNT, + entity='household', + filter_variable_name='is_in_poverty', + filter_variable_value=1 + ), + # Persons in poverty + AggregateChange( + baseline_simulation=baseline_sim, + comparison_simulation=comparison_sim, + variable_name='person_id', + aggregate_function=AggregateType.COUNT, + entity='person', + filter_variable_name='is_in_poverty', + filter_variable_value=1 + ), + # Mean benefits for poor households + AggregateChange( + baseline_simulation=baseline_sim, + comparison_simulation=comparison_sim, + variable_name='benefits', + aggregate_function=AggregateType.MEAN, + entity='person', + filter_variable_name='is_in_poverty', + filter_variable_value=1 + ), + ] + + results = AggregateChange.run(analysis) + + # Poverty decreased + assert results[0].change < 0 # Fewer poor households + assert results[1].change < 0 # Fewer poor persons + + # Benefits increased for those who were poor + assert results[2].baseline_value > 0 diff --git a/tests/test_aggregate_utils.py b/tests/test_aggregate_utils.py deleted file mode 100644 index 39090dc6..00000000 --- a/tests/test_aggregate_utils.py +++ /dev/null @@ -1,441 +0,0 @@ -""" -Unit tests for AggregateUtils cross-entity mapping functionality. - -These tests verify that variables can be correctly mapped between entities -(person <-> household, person <-> tax_unit, etc.) while preserving weights. 
-""" - -import pytest -import pandas as pd -from microdf import MicroDataFrame -from policyengine.models.aggregate import AggregateUtils, AggregateType - - -class MockSimulation: - """Mock simulation for testing.""" - def __init__(self, result): - self.result = result - - -@pytest.fixture -def sample_person_household_tables(): - """ - Create sample tables with person and household entities. - - Structure: - - 2 households (ids 0, 1) - - 4 persons (ids 0, 1, 2, 3) - - Persons 0, 1 belong to household 0 - - Persons 2, 3 belong to household 1 - """ - person_table = pd.DataFrame({ - 'person_id': [0, 1, 2, 3], - 'person_household_id': [0, 0, 1, 1], - 'person_weight': [1.0, 1.0, 1.0, 1.0], - 'age': [30, 5, 45, 40], - 'employment_income': [50000, 0, 60000, 55000], - }) - - household_table = pd.DataFrame({ - 'household_id': [0, 1], - 'household_weight': [1.0, 1.0], - 'household_net_income': [50000, 115000], - 'is_in_poverty': [1, 0], # household 0 is in poverty, household 1 is not - }) - - return { - 'person': person_table, - 'household': household_table - } - - -@pytest.fixture -def weighted_person_household_tables(): - """ - Create sample tables with different weights to test weight preservation. - - Structure: - - 2 households with different weights - - 4 persons with weights matching their households - """ - person_table = pd.DataFrame({ - 'person_id': [0, 1, 2, 3], - 'person_household_id': [0, 0, 1, 1], - 'person_weight': [100.0, 100.0, 200.0, 200.0], # Different weights - 'age': [30, 5, 45, 40], - 'employment_income': [50000, 0, 60000, 55000], - }) - - household_table = pd.DataFrame({ - 'household_id': [0, 1], - 'household_weight': [100.0, 200.0], # Different weights - 'household_net_income': [50000, 115000], - 'is_in_poverty': [1, 0], - }) - - return { - 'person': person_table, - 'household': household_table - } - - -class TestPrepareTablesWithWeights: - """Test that prepare_tables correctly creates MicroDataFrames with weights.""" - - def test_prepare_tables_creates_microdataframes(self, sample_person_household_tables): - """Test that tables with weight columns become MicroDataFrames.""" - mock_sim = MockSimulation(sample_person_household_tables) - tables = AggregateUtils.prepare_tables(mock_sim) - - # Both tables should be MicroDataFrames - assert isinstance(tables['person'], MicroDataFrame) - assert isinstance(tables['household'], MicroDataFrame) - - # Check that weights are set correctly - assert tables['person'].weights_col == 'person_weight' - assert tables['household'].weights_col == 'household_weight' - - def test_prepare_tables_without_weights(self): - """Test that tables without weight columns remain regular DataFrames.""" - tables_without_weights = { - 'person': pd.DataFrame({ - 'person_id': [0, 1], - 'age': [30, 40] - }) - } - mock_sim = MockSimulation(tables_without_weights) - tables = AggregateUtils.prepare_tables(mock_sim) - - # Should be regular DataFrame, not MicroDataFrame - assert isinstance(tables['person'], pd.DataFrame) - assert not isinstance(tables['person'], MicroDataFrame) - - -class TestInferEntity: - """Test entity inference from variable names.""" - - def test_infer_entity_person_variable(self, sample_person_household_tables): - """Test inferring entity for a person-level variable.""" - entity = AggregateUtils.infer_entity('age', sample_person_household_tables) - assert entity == 'person' - - def test_infer_entity_household_variable(self, sample_person_household_tables): - """Test inferring entity for a household-level variable.""" - entity = 
AggregateUtils.infer_entity('is_in_poverty', sample_person_household_tables) - assert entity == 'household' - - def test_infer_entity_nonexistent_variable(self, sample_person_household_tables): - """Test that nonexistent variable raises ValueError.""" - with pytest.raises(ValueError, match="Variable nonexistent not found"): - AggregateUtils.infer_entity('nonexistent', sample_person_household_tables) - - -class TestMapVariableAcrossEntities: - """Test cross-entity variable mapping.""" - - def test_map_household_to_person(self, sample_person_household_tables): - """ - Test mapping a household variable to person level. - Each person should get their household's value. - """ - mapped = AggregateUtils.map_variable_across_entities( - sample_person_household_tables['household'], - 'is_in_poverty', - 'household', - 'person', - sample_person_household_tables - ) - - # Persons 0, 1 belong to household 0 (in poverty) - assert mapped.iloc[0] == 1 - assert mapped.iloc[1] == 1 - - # Persons 2, 3 belong to household 1 (not in poverty) - assert mapped.iloc[2] == 0 - assert mapped.iloc[3] == 0 - - def test_map_person_to_household(self, sample_person_household_tables): - """ - Test mapping a person variable to household level. - Should sum persons' values within each household. - """ - mapped = AggregateUtils.map_variable_across_entities( - sample_person_household_tables['person'], - 'employment_income', - 'person', - 'household', - sample_person_household_tables - ) - - # Household 0: persons 0 (50000) + 1 (0) = 50000 - assert mapped.iloc[0] == 50000 - - # Household 1: persons 2 (60000) + 3 (55000) = 115000 - assert mapped.iloc[1] == 115000 - - def test_map_same_entity(self, sample_person_household_tables): - """Test that mapping to same entity returns the variable as-is.""" - mapped = AggregateUtils.map_variable_across_entities( - sample_person_household_tables['person'], - 'age', - 'person', - 'person', - sample_person_household_tables - ) - - pd.testing.assert_series_equal( - mapped, - sample_person_household_tables['person']['age'], - check_names=False - ) - - def test_map_preserves_length(self, sample_person_household_tables): - """Test that mapped series has correct length for target entity.""" - # Household to person: should have 4 entries (4 persons) - mapped_h_to_p = AggregateUtils.map_variable_across_entities( - sample_person_household_tables['household'], - 'is_in_poverty', - 'household', - 'person', - sample_person_household_tables - ) - assert len(mapped_h_to_p) == 4 - - # Person to household: should have 2 entries (2 households) - mapped_p_to_h = AggregateUtils.map_variable_across_entities( - sample_person_household_tables['person'], - 'employment_income', - 'person', - 'household', - sample_person_household_tables - ) - assert len(mapped_p_to_h) == 2 - - -class TestComputeAggregate: - """Test aggregate computation functions.""" - - def test_sum_simple(self): - """Test simple sum aggregation.""" - series = pd.Series([10, 20, 30, 40]) - result = AggregateUtils.compute_aggregate(series, AggregateType.SUM) - assert result == 100.0 - - def test_sum_weighted(self): - """Test weighted sum using MicroDataFrame.""" - df = pd.DataFrame({ - 'value': [10, 20, 30], - 'weight': [1.0, 2.0, 1.0] - }) - mdf = MicroDataFrame(df, weights='weight') - result = AggregateUtils.compute_aggregate(mdf['value'], AggregateType.SUM) - - # Weighted sum: 10*1 + 20*2 + 30*1 = 80 - assert result == 80.0 - - def test_mean_simple(self): - """Test simple mean aggregation.""" - series = pd.Series([10, 20, 30, 40]) - result = 
AggregateUtils.compute_aggregate(series, AggregateType.MEAN) - assert result == 25.0 - - def test_mean_weighted(self): - """Test weighted mean using MicroDataFrame.""" - df = pd.DataFrame({ - 'value': [10, 20, 30], - 'weight': [1.0, 2.0, 1.0] - }) - mdf = MicroDataFrame(df, weights='weight') - result = AggregateUtils.compute_aggregate(mdf['value'], AggregateType.MEAN) - - # Weighted mean: (10*1 + 20*2 + 30*1) / (1+2+1) = 80/4 = 20 - assert result == 20.0 - - def test_median_simple(self): - """Test simple median aggregation.""" - series = pd.Series([10, 20, 30, 40, 50]) - result = AggregateUtils.compute_aggregate(series, AggregateType.MEDIAN) - assert result == 30.0 - - def test_count(self): - """Test count aggregation (counts all entries).""" - series = pd.Series([0, 10, 0, 20, 30, 0]) - result = AggregateUtils.compute_aggregate(series, AggregateType.COUNT) - # COUNT returns the total number of entries, not just non-zero - assert result == 6.0 - - # To count only non-zero values, filter first then count - non_zero = series[series > 0] - result_filtered = AggregateUtils.compute_aggregate(non_zero, AggregateType.COUNT) - assert result_filtered == 3.0 - - def test_empty_series(self): - """Test that empty series returns 0.""" - series = pd.Series([]) - for agg_type in [AggregateType.SUM, AggregateType.MEAN, AggregateType.MEDIAN, AggregateType.COUNT]: - result = AggregateUtils.compute_aggregate(series, agg_type) - assert result == 0.0 - - -class TestPovertyRateScenario: - """ - Test the specific poverty rate scenario that was giving 1% result. - - This tests the complete flow: prepare tables, map variables, apply filters, - and compute aggregates with weights. - """ - - def test_poverty_rate_with_household_filter_person_aggregation(self, weighted_person_household_tables): - """ - Test computing poverty rate at person level with household-level filter. - - Scenario: - - Filter: households in poverty (is_in_poverty == 1) - - Variable: count of persons - - This should count persons in poor households - """ - # Prepare tables as they would be in production - mock_sim = MockSimulation(weighted_person_household_tables) - tables = AggregateUtils.prepare_tables(mock_sim) - - # Step 1: Get household filter variable and map to person level - household_df = tables['household'] - filter_variable = 'is_in_poverty' - - # Map household filter to person level - mapped_filter = AggregateUtils.map_variable_across_entities( - household_df, - filter_variable, - 'household', - 'person', - tables - ) - - # Build filter mask at person level - person_table = tables['person'] - mask = mapped_filter == 1 - - # Step 2: Filter the person table - filtered_table = person_table[mask] - - # Step 3: Count persons (weighted) - # Persons 0 and 1 are in poor households, each with weight 100 - count = AggregateUtils.compute_aggregate( - filtered_table['person_id'], - AggregateType.COUNT - ) - - # Should count 2 persons (weighted count with MicroDataFrame should be 200) - # But COUNT just counts entries > 0, not weighted - assert count == 2.0 - - # For weighted sum of persons in poverty: - sum_weights = AggregateUtils.compute_aggregate( - filtered_table['person_weight'], - AggregateType.SUM - ) - assert sum_weights == 200.0 # 100 + 100 - - def test_poverty_rate_household_level(self, weighted_person_household_tables): - """ - Test computing poverty rate at household level. - - This is more straightforward - just filter households and count. 
- """ - mock_sim = MockSimulation(weighted_person_household_tables) - tables = AggregateUtils.prepare_tables(mock_sim) - - household_table = tables['household'] - - # Filter to households in poverty - mask = household_table['is_in_poverty'] == 1 - filtered = household_table[mask] - - # Count households - count = AggregateUtils.compute_aggregate( - filtered['is_in_poverty'], - AggregateType.COUNT - ) - assert count == 1.0 # Only 1 household in poverty - - # Weighted sum - sum_weights = AggregateUtils.compute_aggregate( - filtered['household_weight'], - AggregateType.SUM - ) - assert sum_weights == 100.0 # Weight of household 0 - - def test_mean_income_in_poor_households(self, weighted_person_household_tables): - """ - Test computing mean income for persons in poor households. - - This tests the complete cross-entity flow with weights. - """ - mock_sim = MockSimulation(weighted_person_household_tables) - tables = AggregateUtils.prepare_tables(mock_sim) - - # Step 1: Map household poverty status to person level - mapped_poverty = AggregateUtils.map_variable_across_entities( - tables['household'], - 'is_in_poverty', - 'household', - 'person', - tables - ) - - # Step 2: Filter persons in poor households - person_table = tables['person'] - mask = mapped_poverty == 1 - filtered_table = person_table[mask] - - # Step 3: Compute mean employment income - mean_income = AggregateUtils.compute_aggregate( - filtered_table['employment_income'], - AggregateType.MEAN - ) - - # Persons 0 (income 50000, weight 100) and 1 (income 0, weight 100) - # Weighted mean: (50000*100 + 0*100) / (100+100) = 25000 - assert mean_income == 25000.0 - - -class TestEdgeCases: - """Test edge cases and error handling.""" - - def test_missing_link_column(self): - """Test error when link column is missing.""" - tables = { - 'person': pd.DataFrame({'age': [30, 40]}), - 'household': pd.DataFrame({'income': [50000, 60000]}) - } - - with pytest.raises(ValueError, match="Link column .* not found"): - AggregateUtils.map_variable_across_entities( - tables['household'], - 'income', - 'household', - 'person', - tables - ) - - def test_unknown_aggregate_function(self): - """Test error with unknown aggregate function.""" - series = pd.Series([10, 20, 30]) - - with pytest.raises(ValueError, match="Unknown aggregate function"): - AggregateUtils.compute_aggregate(series, 'unknown_function') - - def test_map_with_missing_entity(self): - """Test error when entity doesn't exist in tables.""" - tables = { - 'person': pd.DataFrame({'age': [30, 40]}) - } - - with pytest.raises(ValueError, match="No known link"): - AggregateUtils.map_variable_across_entities( - tables['person'], - 'age', - 'person', - 'nonexistent_entity', - tables - ) From 390fe6a5ef413d482d2fa9e9763d2e27e04f5a13 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Thu, 23 Oct 2025 15:52:38 +0100 Subject: [PATCH 12/35] Update --- src/policyengine/models/aggregate.py | 31 +++++++++++++++-------- tests/test_aggregate.py | 37 +++++++++++++++++----------- tests/test_aggregate_change.py | 18 +++++++------- uv.lock | 16 ++++++++---- 4 files changed, 64 insertions(+), 38 deletions(-) diff --git a/src/policyengine/models/aggregate.py b/src/policyengine/models/aggregate.py index 166bcfae..003b9aa4 100644 --- a/src/policyengine/models/aggregate.py +++ b/src/policyengine/models/aggregate.py @@ -3,7 +3,7 @@ from uuid import uuid4 import pandas as pd -from microdf import MicroDataFrame +from microdf import MicroDataFrame, MicroSeries from pydantic import BaseModel, ConfigDict, Field, 
SkipValidation if TYPE_CHECKING: @@ -148,11 +148,16 @@ def _map_variable( group_values = series.values person_group_ids = person_table[link_col].values - return pd.Series( - [group_values[int(gid)] if int(gid) < len(group_values) else 0 - for gid in person_group_ids], - index=person_table.index - ) + mapped_values = [ + group_values[int(gid)] if int(gid) < len(group_values) else 0 + for gid in person_group_ids + ] + + # Return MicroSeries with person weights + weight_col = f"{target_entity}_weight" + if isinstance(person_table, MicroDataFrame) and weight_col in person_table.columns: + return MicroSeries(mapped_values, weights=person_table[weight_col]) + return pd.Series(mapped_values, index=person_table.index) # Person to group: aggregate up elif source_entity == "person" and target_entity != "person": @@ -170,10 +175,13 @@ def _map_variable( }).groupby(link_col)['value'].sum() target_table = self.tables[target_entity] - return pd.Series( - [grouped.get(i, 0) for i in range(len(target_table))], - index=target_table.index - ) + mapped_values = [grouped.get(i, 0) for i in range(len(target_table))] + + # Return MicroSeries with target entity weights + weight_col = f"{target_entity}_weight" + if isinstance(target_table, MicroDataFrame) and weight_col in target_table.columns: + return MicroSeries(mapped_values, weights=target_table[weight_col]) + return pd.Series(mapped_values, index=target_table.index) # Group to group: via person else: @@ -202,6 +210,9 @@ def aggregate(series: pd.Series, function: AggregateType) -> float: elif function == AggregateType.MEDIAN: return float(series.median()) elif function == AggregateType.COUNT: + # For MicroSeries, sum the weights to get weighted population count + if isinstance(series, MicroSeries): + return float(series.weights.sum()) return float(len(series)) else: raise ValueError(f"Unknown aggregate function: {function}") diff --git a/tests/test_aggregate.py b/tests/test_aggregate.py index cc38040a..d956e1c8 100644 --- a/tests/test_aggregate.py +++ b/tests/test_aggregate.py @@ -88,7 +88,8 @@ def test_count(self, sample_tables): entity='person' ) results = Aggregate.run([agg]) - assert results[0].value == 4.0 + # Weighted count: sum of person weights = 100 + 100 + 200 + 200 = 600 + assert results[0].value == 600.0 def test_median(self, sample_tables): """Test median aggregation.""" @@ -131,7 +132,8 @@ def test_value_filter(self, sample_tables): filter_variable_value=30 ) results = Aggregate.run([agg]) - assert results[0].value == 1.0 + # Weighted count: person 0 has age 30 and weight 100 + assert results[0].value == 100.0 def test_range_filter_leq(self, sample_tables): """Test filtering with <= operator.""" @@ -193,8 +195,9 @@ def test_quantile_filter_leq(self, sample_tables): filter_variable_quantile_leq=0.5 ) results = Aggregate.run([agg]) - # Bottom 50% by age should have at least 2 people - assert results[0].value >= 2.0 + # Weighted median age is 40, so includes ages <= 40: persons 0, 1, 3 + # Weighted count: 100 + 100 + 200 = 400 + assert results[0].value == 400.0 def test_quantile_filter_geq(self, sample_tables): """Test filtering with quantile_geq.""" @@ -208,8 +211,9 @@ def test_quantile_filter_geq(self, sample_tables): filter_variable_quantile_geq=0.5 ) results = Aggregate.run([agg]) - # Top 50% by age should have at least 2 people - assert results[0].value >= 2.0 + # Weighted median age is 40, so includes ages >= 40: persons 2, 3 + # Weighted count: 200 + 200 = 400 + assert results[0].value == 400.0 class TestCrossEntity: @@ -228,7 +232,8 @@ def 
test_household_filter_on_person_aggregation(self, sample_tables): ) results = Aggregate.run([agg]) # Persons in poor households (household 0): persons 0 and 1 - assert results[0].value == 2.0 + # Weighted count: 100 + 100 = 200 + assert results[0].value == 200.0 def test_person_to_household_aggregation(self, sample_tables): """Test aggregating person variable at household level.""" @@ -240,8 +245,11 @@ def test_person_to_household_aggregation(self, sample_tables): entity='household' ) results = Aggregate.run([agg]) - # Employment income summed to household level: 50000 + 115000 = 165,000 - assert results[0].value == 165_000.0 + # Employment income summed to household level with household weights: + # Household 0: (50000 + 0) * 100 = 5,000,000 + # Household 1: (60000 + 55000) * 200 = 23,000,000 + # Total weighted sum: 28,000,000 + assert results[0].value == 28_000_000.0 def test_poverty_rate_calculation(self, sample_tables): """Test calculating poverty rate.""" @@ -267,7 +275,8 @@ def test_poverty_rate_calculation(self, sample_tables): results = Aggregate.run([poor, total]) poverty_rate = results[0].value / results[1].value - assert poverty_rate == 0.5 # 2 out of 4 persons + # Weighted: 200 poor / 600 total = 1/3 + assert round(poverty_rate, 3) == 0.333 def test_mean_income_for_poor(self, sample_tables): """Test mean income for persons in poor households.""" @@ -318,7 +327,7 @@ def test_batch_same_simulation(self, sample_tables): assert len(results) == 3 assert results[0].value == 28_000_000.0 assert round(results[1].value, 2) == 34.17 - assert results[2].value == 4.0 + assert results[2].value == 600.0 # Weighted count def test_batch_different_filters(self, sample_tables): """Test batching aggregates with different filters.""" @@ -345,8 +354,8 @@ def test_batch_different_filters(self, sample_tables): results = Aggregate.run(aggregates) assert len(results) == 2 - assert results[0].value == 1.0 # Children - assert results[1].value == 3.0 # Adults + assert results[0].value == 100.0 # Children: person 1 weight 100 + assert results[1].value == 500.0 # Adults: persons 0,2,3 weights 100+200+200 class TestEdgeCases: @@ -409,7 +418,7 @@ def test_poverty_by_age_group(self, sample_tables): ) results = Aggregate.run([children_poor]) - assert results[0].value == 1.0 # Person 1 (age 5) + assert results[0].value == 100.0 # Person 1 (age 5, weight 100) def test_multiple_aggregations(self, sample_tables): """Test running multiple different aggregations together.""" diff --git a/tests/test_aggregate_change.py b/tests/test_aggregate_change.py index f0f529fa..11938a5d 100644 --- a/tests/test_aggregate_change.py +++ b/tests/test_aggregate_change.py @@ -149,14 +149,14 @@ def test_count_change(self, baseline_tables, comparison_tables): results = AggregateChange.run([agg_change]) result = results[0] - # Baseline: 1 household in poverty - assert result.baseline_value == 1.0 + # Baseline: household 0 in poverty with weight 100 + assert result.baseline_value == 100.0 # Comparison: 0 households in poverty assert result.comparison_value == 0.0 - # Change: -1 household - assert result.change == -1.0 + # Change: -100 (weighted household count) + assert result.change == -100.0 class TestCrossEntityChanges: @@ -181,14 +181,14 @@ def test_persons_in_poverty_change(self, baseline_tables, comparison_tables): results = AggregateChange.run([agg_change]) result = results[0] - # Baseline: 2 persons in poor households (persons 0, 1) - assert result.baseline_value == 2.0 + # Baseline: persons 0, 1 in poor households with weights 
100 + 100 = 200 + assert result.baseline_value == 200.0 # Comparison: 0 persons in poor households assert result.comparison_value == 0.0 - # Change: -2 persons - assert result.change == -2.0 + # Change: -200 (weighted person count) + assert result.change == -200.0 def test_mean_benefits_for_poor(self, baseline_tables, comparison_tables): """Test change in mean benefits for persons in poor households.""" @@ -255,7 +255,7 @@ def test_batch_multiple_changes(self, baseline_tables, comparison_tables): assert len(results) == 3 assert results[0].change == 800_000.0 # Benefits increased assert results[1].change == 0.0 # Employment income unchanged - assert results[2].change == -2.0 # Poverty count decreased + assert results[2].change == -200.0 # Poverty count decreased (weighted) class TestRangeFilters: diff --git a/uv.lock b/uv.lock index c7bf4216..c8d1e963 100644 --- a/uv.lock +++ b/uv.lock @@ -1284,7 +1284,7 @@ wheels = [ [[package]] name = "policyengine" -version = "1.0.0" +version = "3.0.0" source = { editable = "." } dependencies = [ { name = "alembic" }, @@ -1311,6 +1311,9 @@ dev = [ { name = "furo" }, { name = "itables" }, { name = "jupyter-book" }, + { name = "policyengine-core" }, + { name = "policyengine-uk" }, + { name = "policyengine-us" }, { name = "pytest" }, { name = "pytest-asyncio" }, { name = "ruff" }, @@ -1340,9 +1343,12 @@ requires-dist = [ { name = "jupyter-book", marker = "extra == 'dev'" }, { name = "microdf-python" }, { name = "pandas", specifier = ">=2.0.0" }, + { name = "policyengine-core", marker = "extra == 'dev'", specifier = ">=3.10" }, { name = "policyengine-core", marker = "extra == 'uk'", specifier = ">=3.10" }, { name = "policyengine-core", marker = "extra == 'us'", specifier = ">=3.10" }, - { name = "policyengine-uk", marker = "extra == 'uk'" }, + { name = "policyengine-uk", marker = "extra == 'dev'", specifier = ">=2.51.0" }, + { name = "policyengine-uk", marker = "extra == 'uk'", specifier = ">=2.51.0" }, + { name = "policyengine-us", marker = "extra == 'dev'", specifier = ">=1.213.1" }, { name = "policyengine-us", marker = "extra == 'us'", specifier = ">=1.213.1" }, { name = "psycopg2-binary", specifier = ">=2.9.0" }, { name = "pydantic", specifier = ">=2.0.0" }, @@ -1387,7 +1393,7 @@ wheels = [ [[package]] name = "policyengine-uk" -version = "2.50.0" +version = "2.55.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "microdf-python" }, @@ -1395,9 +1401,9 @@ dependencies = [ { name = "pydantic" }, { name = "tables" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/68/52/ad3abc8265b424238a545a4e1f95e2d7e4f3511ea3fa02ad1eca53df7857/policyengine_uk-2.50.0.tar.gz", hash = "sha256:f6ec9b8abce7995b48db70c700bd5096abea60e2bac03ec1cbffaafe1a93f6a8", size = 1048546, upload-time = "2025-09-03T09:16:54.493Z" } +sdist = { url = "https://files.pythonhosted.org/packages/d3/af/d796c74d16536e072fa1cd5fb2ab85d66d9c62610db631a548d5161a6cca/policyengine_uk-2.55.3.tar.gz", hash = "sha256:28a2e3c9f63cd89bce4ddaded6861f75e6116863c9bad32731b77d0e9731e27c", size = 1051059, upload-time = "2025-10-22T10:04:54.355Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/78/9e/e2a3089864d636a7c1344b06d72219f66b02791272f86bbdfca23c511f6e/policyengine_uk-2.50.0-py3-none-any.whl", hash = "sha256:d099d14eda66ea8872a6e2a41110e8da298818dbe9e573a75ab5cdd65d85bf03", size = 1605337, upload-time = "2025-09-03T09:16:52.616Z" }, + { url = 
"https://files.pythonhosted.org/packages/11/bf/b64aeb51d68d3a80ee9b5e7c35e978fa1a6b16278816a8d370ce7d8623fc/policyengine_uk-2.55.3-py3-none-any.whl", hash = "sha256:13e29e10e6b45b278fb894474991ffd0560c54e1e8e45571560d026aec1dcf01", size = 1611162, upload-time = "2025-10-22T10:04:52.352Z" }, ] [[package]] From 2f7a628807519bcfb25728a2a572b31af7d881a6 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Sat, 8 Nov 2025 11:50:33 +0000 Subject: [PATCH 13/35] Go back a bit --- .../lib/policyengine/__init__.py => CLAUDE.md | 0 build/lib/policyengine/database/__init__.py | 68 --- build/lib/policyengine/database/aggregate.py | 101 ---- .../policyengine/database/aggregate_change.py | 122 ----- .../baseline_parameter_value_table.py | 112 ---- .../database/baseline_variable_table.py | 81 --- build/lib/policyengine/database/database.py | 301 ----------- .../policyengine/database/dataset_table.py | 94 ---- .../policyengine/database/dynamic_table.py | 68 --- build/lib/policyengine/database/link.py | 8 - .../lib/policyengine/database/model_table.py | 60 --- .../database/model_version_table.py | 73 --- .../policyengine/database/parameter_table.py | 92 ---- .../database/parameter_value_table.py | 108 ---- .../lib/policyengine/database/policy_table.py | 136 ----- .../database/report_element_table.py | 106 ---- .../lib/policyengine/database/report_table.py | 120 ----- .../policyengine/database/simulation_table.py | 225 -------- .../lib/policyengine/database/table_mixin.py | 80 --- build/lib/policyengine/database/user_table.py | 57 --- .../database/versioned_dataset_table.py | 45 -- build/lib/policyengine/models/__init__.py | 39 -- build/lib/policyengine/models/aggregate.py | 132 ----- .../policyengine/models/aggregate_change.py | 143 ------ .../models/baseline_parameter_value.py | 16 - .../policyengine/models/baseline_variable.py | 12 - build/lib/policyengine/models/dataset.py | 18 - build/lib/policyengine/models/dynamic.py | 15 - build/lib/policyengine/models/model.py | 126 ----- build/lib/policyengine/models/parameter.py | 14 - .../policyengine/models/parameter_value.py | 14 - build/lib/policyengine/models/policy.py | 17 - .../policyengine/models/policyengine_uk.py | 113 ----- .../policyengine/models/policyengine_us.py | 115 ----- build/lib/policyengine/models/report.py | 20 - .../lib/policyengine/models/report_element.py | 38 -- build/lib/policyengine/models/simulation.py | 35 -- build/lib/policyengine/models/user.py | 14 - .../policyengine/models/versioned_dataset.py | 12 - build/lib/policyengine/utils/charts.py | 286 ----------- build/lib/policyengine/utils/compress.py | 20 - build/lib/policyengine/utils/datasets.py | 71 --- src/policyengine/database/__init__.py | 62 --- src/policyengine/database/aggregate.py | 110 ---- src/policyengine/database/aggregate_change.py | 131 ----- .../baseline_parameter_value_table.py | 112 ---- .../database/baseline_variable_table.py | 81 --- src/policyengine/database/database.py | 339 ------------- src/policyengine/database/dataset_table.py | 94 ---- src/policyengine/database/dynamic_table.py | 68 --- src/policyengine/database/link.py | 8 - src/policyengine/database/model_table.py | 60 --- .../database/model_version_table.py | 73 --- src/policyengine/database/parameter_table.py | 92 ---- .../database/parameter_value_table.py | 107 ---- src/policyengine/database/policy_table.py | 136 ----- src/policyengine/database/simulation_table.py | 231 --------- src/policyengine/database/table_mixin.py | 80 --- .../database/versioned_dataset_table.py | 45 -- 
src/policyengine/models/__init__.py | 8 +- src/policyengine/models/aggregate.py | 297 ----------- src/policyengine/models/aggregate_change.py | 127 ----- .../models/baseline_parameter_value.py | 16 - src/policyengine/models/baseline_variable.py | 12 - src/policyengine/models/dataset.py | 14 +- src/policyengine/models/dataset_version.py | 15 + src/policyengine/models/dynamic.py | 3 +- src/policyengine/models/model.py | 126 ----- src/policyengine/models/model_version.py | 14 - src/policyengine/models/parameter.py | 6 +- src/policyengine/models/policyengine_uk.py | 117 ----- src/policyengine/models/policyengine_us.py | 119 ----- src/policyengine/models/simulation.py | 25 +- src/policyengine/models/tax_benefit_model.py | 15 + .../models/tax_benefit_model_version.py | 6 +- src/policyengine/models/variable.py | 12 + src/policyengine/models/versioned_dataset.py | 12 - src/policyengine/utils/charts.py | 286 ----------- src/policyengine/utils/compress.py | 20 - src/policyengine/utils/datasets.py | 71 --- tests/test_aggregate.py | 451 ----------------- tests/test_aggregate_change.py | 479 ------------------ tests/test_database_init.py | 259 ---------- tests/test_database_models.py | 384 -------------- tests/test_database_postgres.py | 169 ------ tests/test_database_simple.py | 277 ---------- 86 files changed, 65 insertions(+), 8431 deletions(-) rename build/lib/policyengine/__init__.py => CLAUDE.md (100%) delete mode 100644 build/lib/policyengine/database/__init__.py delete mode 100644 build/lib/policyengine/database/aggregate.py delete mode 100644 build/lib/policyengine/database/aggregate_change.py delete mode 100644 build/lib/policyengine/database/baseline_parameter_value_table.py delete mode 100644 build/lib/policyengine/database/baseline_variable_table.py delete mode 100644 build/lib/policyengine/database/database.py delete mode 100644 build/lib/policyengine/database/dataset_table.py delete mode 100644 build/lib/policyengine/database/dynamic_table.py delete mode 100644 build/lib/policyengine/database/link.py delete mode 100644 build/lib/policyengine/database/model_table.py delete mode 100644 build/lib/policyengine/database/model_version_table.py delete mode 100644 build/lib/policyengine/database/parameter_table.py delete mode 100644 build/lib/policyengine/database/parameter_value_table.py delete mode 100644 build/lib/policyengine/database/policy_table.py delete mode 100644 build/lib/policyengine/database/report_element_table.py delete mode 100644 build/lib/policyengine/database/report_table.py delete mode 100644 build/lib/policyengine/database/simulation_table.py delete mode 100644 build/lib/policyengine/database/table_mixin.py delete mode 100644 build/lib/policyengine/database/user_table.py delete mode 100644 build/lib/policyengine/database/versioned_dataset_table.py delete mode 100644 build/lib/policyengine/models/__init__.py delete mode 100644 build/lib/policyengine/models/aggregate.py delete mode 100644 build/lib/policyengine/models/aggregate_change.py delete mode 100644 build/lib/policyengine/models/baseline_parameter_value.py delete mode 100644 build/lib/policyengine/models/baseline_variable.py delete mode 100644 build/lib/policyengine/models/dataset.py delete mode 100644 build/lib/policyengine/models/dynamic.py delete mode 100644 build/lib/policyengine/models/model.py delete mode 100644 build/lib/policyengine/models/parameter.py delete mode 100644 build/lib/policyengine/models/parameter_value.py delete mode 100644 build/lib/policyengine/models/policy.py delete mode 100644 
build/lib/policyengine/models/policyengine_uk.py delete mode 100644 build/lib/policyengine/models/policyengine_us.py delete mode 100644 build/lib/policyengine/models/report.py delete mode 100644 build/lib/policyengine/models/report_element.py delete mode 100644 build/lib/policyengine/models/simulation.py delete mode 100644 build/lib/policyengine/models/user.py delete mode 100644 build/lib/policyengine/models/versioned_dataset.py delete mode 100644 build/lib/policyengine/utils/charts.py delete mode 100644 build/lib/policyengine/utils/compress.py delete mode 100644 build/lib/policyengine/utils/datasets.py delete mode 100644 src/policyengine/database/__init__.py delete mode 100644 src/policyengine/database/aggregate.py delete mode 100644 src/policyengine/database/aggregate_change.py delete mode 100644 src/policyengine/database/baseline_parameter_value_table.py delete mode 100644 src/policyengine/database/baseline_variable_table.py delete mode 100644 src/policyengine/database/database.py delete mode 100644 src/policyengine/database/dataset_table.py delete mode 100644 src/policyengine/database/dynamic_table.py delete mode 100644 src/policyengine/database/link.py delete mode 100644 src/policyengine/database/model_table.py delete mode 100644 src/policyengine/database/model_version_table.py delete mode 100644 src/policyengine/database/parameter_table.py delete mode 100644 src/policyengine/database/parameter_value_table.py delete mode 100644 src/policyengine/database/policy_table.py delete mode 100644 src/policyengine/database/simulation_table.py delete mode 100644 src/policyengine/database/table_mixin.py delete mode 100644 src/policyengine/database/versioned_dataset_table.py delete mode 100644 src/policyengine/models/aggregate.py delete mode 100644 src/policyengine/models/aggregate_change.py delete mode 100644 src/policyengine/models/baseline_parameter_value.py delete mode 100644 src/policyengine/models/baseline_variable.py create mode 100644 src/policyengine/models/dataset_version.py delete mode 100644 src/policyengine/models/model.py delete mode 100644 src/policyengine/models/model_version.py delete mode 100644 src/policyengine/models/policyengine_uk.py delete mode 100644 src/policyengine/models/policyengine_us.py create mode 100644 src/policyengine/models/tax_benefit_model.py rename build/lib/policyengine/models/model_version.py => src/policyengine/models/tax_benefit_model_version.py (69%) create mode 100644 src/policyengine/models/variable.py delete mode 100644 src/policyengine/models/versioned_dataset.py delete mode 100644 src/policyengine/utils/charts.py delete mode 100644 src/policyengine/utils/compress.py delete mode 100644 src/policyengine/utils/datasets.py delete mode 100644 tests/test_aggregate.py delete mode 100644 tests/test_aggregate_change.py delete mode 100644 tests/test_database_init.py delete mode 100644 tests/test_database_models.py delete mode 100644 tests/test_database_postgres.py delete mode 100644 tests/test_database_simple.py diff --git a/build/lib/policyengine/__init__.py b/CLAUDE.md similarity index 100% rename from build/lib/policyengine/__init__.py rename to CLAUDE.md diff --git a/build/lib/policyengine/database/__init__.py b/build/lib/policyengine/database/__init__.py deleted file mode 100644 index 88e1a21b..00000000 --- a/build/lib/policyengine/database/__init__.py +++ /dev/null @@ -1,68 +0,0 @@ -from .baseline_parameter_value_table import ( - BaselineParameterValueTable, - baseline_parameter_value_table_link, -) -from .baseline_variable_table import ( - 
BaselineVariableTable, - baseline_variable_table_link, -) -from .database import Database -from .dataset_table import DatasetTable, dataset_table_link -from .dynamic_table import DynamicTable, dynamic_table_link -from .link import TableLink - -# Import all table classes and links -from .model_table import ModelTable, model_table_link -from .model_version_table import ModelVersionTable, model_version_table_link -from .parameter_table import ParameterTable, parameter_table_link -from .parameter_value_table import ( - ParameterValueTable, - parameter_value_table_link, -) -from .policy_table import PolicyTable, policy_table_link -from .simulation_table import SimulationTable, simulation_table_link -from .versioned_dataset_table import ( - VersionedDatasetTable, - versioned_dataset_table_link, -) -from .report_table import ReportTable, report_table_link -from .report_element_table import ReportElementTable, report_element_table_link -from .aggregate import AggregateTable, aggregate_table_link -from .aggregate_change import AggregateChangeTable, aggregate_change_table_link - -__all__ = [ - "Database", - "TableLink", - # Tables - "ModelTable", - "ModelVersionTable", - "DatasetTable", - "VersionedDatasetTable", - "PolicyTable", - "DynamicTable", - "ParameterTable", - "ParameterValueTable", - "BaselineParameterValueTable", - "BaselineVariableTable", - "SimulationTable", - "ReportTable", - "ReportElementTable", - "AggregateTable", - "AggregateChangeTable", - # Links - "model_table_link", - "model_version_table_link", - "dataset_table_link", - "versioned_dataset_table_link", - "policy_table_link", - "dynamic_table_link", - "parameter_table_link", - "parameter_value_table_link", - "baseline_parameter_value_table_link", - "baseline_variable_table_link", - "simulation_table_link", - "report_table_link", - "report_element_table_link", - "aggregate_table_link", - "aggregate_change_table_link", -] diff --git a/build/lib/policyengine/database/aggregate.py b/build/lib/policyengine/database/aggregate.py deleted file mode 100644 index 44c8aacd..00000000 --- a/build/lib/policyengine/database/aggregate.py +++ /dev/null @@ -1,101 +0,0 @@ -from typing import TYPE_CHECKING -from uuid import uuid4 - -from sqlmodel import Field, SQLModel - -from policyengine.database.link import TableLink -from policyengine.models.aggregate import Aggregate -from policyengine.models import Simulation - -if TYPE_CHECKING: - from .database import Database - - -class AggregateTable(SQLModel, table=True): - __tablename__ = "aggregates" - - id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True) - simulation_id: str = Field( - foreign_key="simulations.id", ondelete="CASCADE" - ) - entity: str - variable_name: str - year: int | None = None - filter_variable_name: str | None = None - filter_variable_value: str | None = None - filter_variable_leq: float | None = None - filter_variable_geq: float | None = None - aggregate_function: str - reportelement_id: str | None = None - value: float | None = None - - @classmethod - def convert_from_model(cls, model: Aggregate, database: "Database" = None) -> "AggregateTable": - """Convert an Aggregate instance to an AggregateTable instance. 
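A note on the weighted aggregation behaviour exercised by the tests in the first patch above: it comes from microdf, a declared dependency of this package. A MicroSeries carries a weight per row, so COUNT becomes a sum of weights rather than a row count, and mean/median are weight-aware. A minimal sketch; the ages and weights are inferred from the test fixtures' comments (the fixture construction itself is not shown in these hunks):

    import pandas as pd
    from microdf import MicroSeries

    # Ages and weights inferred from the test comments above.
    ages = MicroSeries([30, 5, 45, 40], weights=pd.Series([100, 100, 200, 200]))

    count = float(ages.weights.sum())    # 600.0 -> the weighted COUNT
    mean = round(float(ages.mean()), 2)  # 34.17 -> the weighted MEAN
    median = float(ages.median())        # 40.0  -> the weighted MEDIAN

This matches the patched AggregateType.COUNT branch, which returns series.weights.sum() for a MicroSeries instead of len(series).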
- - Args: - model: The Aggregate instance to convert - database: The database instance for persisting the simulation if needed - - Returns: - An AggregateTable instance - """ - # Don't try to save the simulation here - it's already being saved - # This prevents circular references - - return cls( - id=model.id, - simulation_id=model.simulation.id if model.simulation else None, - entity=model.entity, - variable_name=model.variable_name, - year=model.year, - filter_variable_name=model.filter_variable_name, - filter_variable_value=model.filter_variable_value, - filter_variable_leq=model.filter_variable_leq, - filter_variable_geq=model.filter_variable_geq, - aggregate_function=model.aggregate_function, - reportelement_id=model.reportelement_id, - value=model.value, - ) - - def convert_to_model(self, database: "Database" = None) -> Aggregate: - """Convert this AggregateTable instance to an Aggregate instance. - - Args: - database: The database instance for resolving the simulation foreign key - - Returns: - An Aggregate instance - """ - from .simulation_table import SimulationTable - from sqlmodel import select - - # Resolve the simulation foreign key - simulation = None - if database and self.simulation_id: - sim_table = database.session.exec( - select(SimulationTable).where(SimulationTable.id == self.simulation_id) - ).first() - if sim_table: - simulation = sim_table.convert_to_model(database) - - return Aggregate( - id=self.id, - simulation=simulation, - entity=self.entity, - variable_name=self.variable_name, - year=self.year, - filter_variable_name=self.filter_variable_name, - filter_variable_value=self.filter_variable_value, - filter_variable_leq=self.filter_variable_leq, - filter_variable_geq=self.filter_variable_geq, - aggregate_function=self.aggregate_function, - reportelement_id=self.reportelement_id, - value=self.value, - ) - - -aggregate_table_link = TableLink( - model_cls=Aggregate, - table_cls=AggregateTable, -) diff --git a/build/lib/policyengine/database/aggregate_change.py b/build/lib/policyengine/database/aggregate_change.py deleted file mode 100644 index 1011ffcc..00000000 --- a/build/lib/policyengine/database/aggregate_change.py +++ /dev/null @@ -1,122 +0,0 @@ -from typing import TYPE_CHECKING -from uuid import uuid4 - -from sqlmodel import Field, SQLModel - -from policyengine.database.link import TableLink -from policyengine.models.aggregate_change import AggregateChange - -if TYPE_CHECKING: - from .database import Database - - -class AggregateChangeTable(SQLModel, table=True): - __tablename__ = "aggregate_changes" - - id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True) - baseline_simulation_id: str = Field( - foreign_key="simulations.id", ondelete="CASCADE" - ) - comparison_simulation_id: str = Field( - foreign_key="simulations.id", ondelete="CASCADE" - ) - entity: str - variable_name: str - year: int | None = None - filter_variable_name: str | None = None - filter_variable_value: str | None = None - filter_variable_leq: float | None = None - filter_variable_geq: float | None = None - aggregate_function: str - reportelement_id: str | None = None - - baseline_value: float | None = None - comparison_value: float | None = None - change: float | None = None - relative_change: float | None = None - - @classmethod - def convert_from_model(cls, model: AggregateChange, database: "Database" = None) -> "AggregateChangeTable": - """Convert an AggregateChange instance to an AggregateChangeTable instance. 
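The baseline_value, comparison_value, change and relative_change columns stored by this table relate in the obvious way, but the computation lives in the AggregateChange model rather than in this table class, so the following is a sketch of the assumed relationship, consistent with the test expectations above (baseline 100.0, comparison 0.0, change -100.0):

    # Assumed relationship between the stored change fields; only the
    # column names come from the table definition above.
    baseline_value = 100.0
    comparison_value = 0.0
    change = comparison_value - baseline_value               # -100.0
    relative_change = (
        change / baseline_value if baseline_value else None  # -1.0
    )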
- - Args: - model: The AggregateChange instance to convert - database: The database instance for persisting the simulations if needed - - Returns: - An AggregateChangeTable instance - """ - return cls( - id=model.id, - baseline_simulation_id=model.baseline_simulation.id if model.baseline_simulation else None, - comparison_simulation_id=model.comparison_simulation.id if model.comparison_simulation else None, - entity=model.entity, - variable_name=model.variable_name, - year=model.year, - filter_variable_name=model.filter_variable_name, - filter_variable_value=model.filter_variable_value, - filter_variable_leq=model.filter_variable_leq, - filter_variable_geq=model.filter_variable_geq, - aggregate_function=model.aggregate_function, - reportelement_id=model.reportelement_id, - baseline_value=model.baseline_value, - comparison_value=model.comparison_value, - change=model.change, - relative_change=model.relative_change, - ) - - def convert_to_model(self, database: "Database" = None) -> AggregateChange: - """Convert this AggregateChangeTable instance to an AggregateChange instance. - - Args: - database: The database instance for resolving simulation foreign keys - - Returns: - An AggregateChange instance - """ - from .simulation_table import SimulationTable - from sqlmodel import select - - # Resolve the simulation foreign keys - baseline_simulation = None - comparison_simulation = None - - if database: - if self.baseline_simulation_id: - sim_table = database.session.exec( - select(SimulationTable).where(SimulationTable.id == self.baseline_simulation_id) - ).first() - if sim_table: - baseline_simulation = sim_table.convert_to_model(database) - - if self.comparison_simulation_id: - sim_table = database.session.exec( - select(SimulationTable).where(SimulationTable.id == self.comparison_simulation_id) - ).first() - if sim_table: - comparison_simulation = sim_table.convert_to_model(database) - - return AggregateChange( - id=self.id, - baseline_simulation=baseline_simulation, - comparison_simulation=comparison_simulation, - entity=self.entity, - variable_name=self.variable_name, - year=self.year, - filter_variable_name=self.filter_variable_name, - filter_variable_value=self.filter_variable_value, - filter_variable_leq=self.filter_variable_leq, - filter_variable_geq=self.filter_variable_geq, - aggregate_function=self.aggregate_function, - reportelement_id=self.reportelement_id, - baseline_value=self.baseline_value, - comparison_value=self.comparison_value, - change=self.change, - relative_change=self.relative_change, - ) - - -aggregate_change_table_link = TableLink( - model_cls=AggregateChange, - table_cls=AggregateChangeTable, -) \ No newline at end of file diff --git a/build/lib/policyengine/database/baseline_parameter_value_table.py b/build/lib/policyengine/database/baseline_parameter_value_table.py deleted file mode 100644 index 6485223c..00000000 --- a/build/lib/policyengine/database/baseline_parameter_value_table.py +++ /dev/null @@ -1,112 +0,0 @@ -from datetime import datetime -from typing import Any -from uuid import uuid4 - -from sqlmodel import JSON, Column, Field, SQLModel -from typing import TYPE_CHECKING - -from policyengine.models import ModelVersion, Parameter, BaselineParameterValue - -from .link import TableLink - -if TYPE_CHECKING: - from .database import Database - - -class BaselineParameterValueTable(SQLModel, table=True): - __tablename__ = "baseline_parameter_values" - __table_args__ = ({"extend_existing": True},) - - id: str = Field(default_factory=lambda: str(uuid4()), 
primary_key=True) - parameter_id: str = Field(nullable=False) # Part of composite foreign key - model_id: str = Field(nullable=False) # Part of composite foreign key - model_version_id: str = Field( - foreign_key="model_versions.id", ondelete="CASCADE" - ) - value: Any | None = Field( - default=None, sa_column=Column(JSON) - ) # JSON field for any type - start_date: datetime = Field(nullable=False) - end_date: datetime | None = Field(default=None) - - @classmethod - def convert_from_model(cls, model: BaselineParameterValue, database: "Database" = None) -> "BaselineParameterValueTable": - """Convert a BaselineParameterValue instance to a BaselineParameterValueTable instance.""" - import math - - # Ensure foreign objects are persisted if database is provided - if database: - if model.parameter: - database.set(model.parameter, commit=False) - if model.model_version: - database.set(model.model_version, commit=False) - - # Handle special float values - value = model.value - if isinstance(value, float): - if math.isinf(value): - value = "Infinity" if value > 0 else "-Infinity" - elif math.isnan(value): - value = "NaN" - - return cls( - id=model.id, - parameter_id=model.parameter.id if model.parameter else None, - model_id=model.parameter.model.id if model.parameter and model.parameter.model else None, - model_version_id=model.model_version.id if model.model_version else None, - value=value, - start_date=model.start_date, - end_date=model.end_date, - ) - - def convert_to_model(self, database: "Database" = None) -> BaselineParameterValue: - """Convert this BaselineParameterValueTable instance to a BaselineParameterValue instance.""" - from .parameter_table import ParameterTable - from .model_version_table import ModelVersionTable - from sqlmodel import select - - # Resolve foreign keys - parameter = None - model_version = None - - if database: - if self.parameter_id and self.model_id: - param_table = database.session.exec( - select(ParameterTable).where( - ParameterTable.id == self.parameter_id, - ParameterTable.model_id == self.model_id - ) - ).first() - if param_table: - parameter = param_table.convert_to_model(database) - - if self.model_version_id: - version_table = database.session.exec( - select(ModelVersionTable).where(ModelVersionTable.id == self.model_version_id) - ).first() - if version_table: - model_version = version_table.convert_to_model(database) - - # Handle special string values - value = self.value - if value == "Infinity": - value = float("inf") - elif value == "-Infinity": - value = float("-inf") - elif value == "NaN": - value = float("nan") - - return BaselineParameterValue( - id=self.id, - parameter=parameter, - model_version=model_version, - value=value, - start_date=self.start_date, - end_date=self.end_date, - ) - - -baseline_parameter_value_table_link = TableLink( - model_cls=BaselineParameterValue, - table_cls=BaselineParameterValueTable, -) diff --git a/build/lib/policyengine/database/baseline_variable_table.py b/build/lib/policyengine/database/baseline_variable_table.py deleted file mode 100644 index e7773c80..00000000 --- a/build/lib/policyengine/database/baseline_variable_table.py +++ /dev/null @@ -1,81 +0,0 @@ -from sqlmodel import Field, SQLModel -from typing import TYPE_CHECKING - -from policyengine.models import ModelVersion, BaselineVariable - -from .link import TableLink - -if TYPE_CHECKING: - from .database import Database - - -class BaselineVariableTable(SQLModel, table=True): - __tablename__ = "baseline_variables" - __table_args__ = ({"extend_existing": 
True},) - - id: str = Field(primary_key=True) # Variable name - model_id: str = Field( - primary_key=True, foreign_key="models.id" - ) # Part of composite key - model_version_id: str = Field( - foreign_key="model_versions.id", ondelete="CASCADE" - ) - entity: str = Field(nullable=False) - label: str | None = Field(default=None) - description: str | None = Field(default=None) - data_type: str | None = Field(default=None) # Data type name - - @classmethod - def convert_from_model(cls, model: BaselineVariable, database: "Database" = None) -> "BaselineVariableTable": - """Convert a BaselineVariable instance to a BaselineVariableTable instance.""" - # Ensure foreign objects are persisted if database is provided - if database and model.model_version: - database.set(model.model_version, commit=False) - - return cls( - id=model.id, - model_id=model.model_version.model.id if model.model_version and model.model_version.model else None, - model_version_id=model.model_version.id if model.model_version else None, - entity=model.entity, - label=model.label, - description=model.description, - data_type=model.data_type.__name__ if model.data_type else None, - ) - - def convert_to_model(self, database: "Database" = None) -> BaselineVariable: - """Convert this BaselineVariableTable instance to a BaselineVariable instance.""" - from .model_version_table import ModelVersionTable - from sqlmodel import select - - # Resolve foreign keys - model_version = None - - if database and self.model_version_id: - version_table = database.session.exec( - select(ModelVersionTable).where(ModelVersionTable.id == self.model_version_id) - ).first() - if version_table: - model_version = version_table.convert_to_model(database) - - # Convert data_type string back to type - data_type = None - if self.data_type: - try: - data_type = eval(self.data_type) - except: - data_type = None - - return BaselineVariable( - id=self.id, - model_version=model_version, - entity=self.entity, - label=self.label, - description=self.description, - data_type=data_type, - ) - - -baseline_variable_table_link = TableLink( - model_cls=BaselineVariable, - table_cls=BaselineVariableTable, -) diff --git a/build/lib/policyengine/database/database.py b/build/lib/policyengine/database/database.py deleted file mode 100644 index 2ae77e1c..00000000 --- a/build/lib/policyengine/database/database.py +++ /dev/null @@ -1,301 +0,0 @@ -from typing import Any - -from sqlmodel import Session, SQLModel - -from .aggregate import aggregate_table_link -from .baseline_parameter_value_table import baseline_parameter_value_table_link -from .baseline_variable_table import baseline_variable_table_link -from .dataset_table import dataset_table_link -from .dynamic_table import dynamic_table_link -from .link import TableLink - -# Import all table links -from .model_table import model_table_link -from .model_version_table import model_version_table_link -from .parameter_table import parameter_table_link -from .parameter_value_table import parameter_value_table_link -from .policy_table import policy_table_link -from .report_element_table import report_element_table_link -from .report_table import report_table_link -from .simulation_table import simulation_table_link -from .user_table import user_table_link -from .versioned_dataset_table import versioned_dataset_table_link - - -class Database: - url: str - - _model_table_links: list[TableLink] = [] - - def __init__(self, url: str): - self.url = url - self.engine = self._create_engine() - self.session = Session(self.engine) - - for link 
in [ - model_table_link, - model_version_table_link, - dataset_table_link, - versioned_dataset_table_link, - policy_table_link, - dynamic_table_link, - parameter_table_link, - parameter_value_table_link, - baseline_parameter_value_table_link, - baseline_variable_table_link, - simulation_table_link, - aggregate_table_link, - user_table_link, - report_table_link, - report_element_table_link, - ]: - self.register_table(link) - - def _create_engine(self): - from sqlmodel import create_engine - - return create_engine(self.url, echo=False) - - def create_tables(self): - """Create all database tables.""" - SQLModel.metadata.create_all(self.engine) - - def drop_tables(self): - """Drop all database tables.""" - SQLModel.metadata.drop_all(self.engine) - - def reset(self): - """Drop and recreate all tables.""" - self.drop_tables() - self.create_tables() - - def __enter__(self): - """Context manager entry - creates a session.""" - self.session = Session(self.engine) - return self.session - - def __exit__(self, exc_type, exc_val, exc_tb): - """Context manager exit - closes the session.""" - if exc_type: - self.session.rollback() - else: - self.session.commit() - self.session.close() - - def register_table(self, link: TableLink): - self._model_table_links.append(link) - # Create the table if not exists - link.table_cls.metadata.create_all(self.engine) - - def get(self, model_cls: type, **kwargs): - """Get a model instance from the database by its attributes.""" - from sqlmodel import select - - # Find the table class for this model - table_link = next( - ( - link - for link in self._model_table_links - if link.model_cls == model_cls - ), - None, - ) - - if table_link is None: - return None - - # Query the database - statement = select(table_link.table_cls).filter_by(**kwargs) - result = self.session.exec(statement).first() - - if result is None: - return None - - # Use the table's convert_to_model method - return result.convert_to_model(self) - - def set(self, object: Any, commit: bool = True): - """Save or update a model instance in the database.""" - from sqlmodel import select - from sqlalchemy.inspection import inspect - - # Find the table class for this model - table_link = next( - ( - link - for link in self._model_table_links - if link.model_cls is type(object) - ), - None, - ) - - if table_link is None: - return - - # Convert model to table instance - table_obj = table_link.table_cls.convert_from_model(object, self) - - # Get primary key columns - mapper = inspect(table_link.table_cls) - pk_cols = [col.name for col in mapper.primary_key] - - # Build query to check if exists - query = select(table_link.table_cls) - for pk_col in pk_cols: - query = query.where( - getattr(table_link.table_cls, pk_col) == getattr(table_obj, pk_col) - ) - - existing = self.session.exec(query).first() - - if existing: - # Update existing record - for key, value in table_obj.model_dump().items(): - setattr(existing, key, value) - self.session.add(existing) - else: - self.session.add(table_obj) - - if commit: - self.session.commit() - - def register_model_version(self, model_version): - """Register a model version with its model and seed objects. 
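For orientation, the get/set pair defined above gives Database a small upsert-style API: set converts a model to its table row via the registered TableLink and inserts or updates by primary key, and get reverses the conversion. A usage sketch; the connection URL is a placeholder, and the Policy constructor arguments assume the model's optional fields have defaults:

    from policyengine.database import Database
    from policyengine.models import Policy

    db = Database("sqlite:///example.db")  # placeholder URL

    policy = Policy(name="Example reform", description="Raises a threshold")
    db.set(policy)                         # insert, or update if the id exists
    loaded = db.get(Policy, id=policy.id)  # round-trips through PolicyTable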
- This replaces all existing parameters, baseline parameter values, - and baseline variables for this model version.""" - # Add or update the model directly to avoid conflicts - from policyengine.utils.compress import compress_data - - from .baseline_parameter_value_table import BaselineParameterValueTable - from .baseline_variable_table import BaselineVariableTable - from .model_table import ModelTable - from .model_version_table import ModelVersionTable - from .parameter_table import ParameterTable - - existing_model = ( - self.session.query(ModelTable) - .filter(ModelTable.id == model_version.model.id) - .first() - ) - if not existing_model: - model_table = ModelTable( - id=model_version.model.id, - name=model_version.model.name, - description=model_version.model.description, - simulation_function=compress_data( - model_version.model.simulation_function - ), - ) - self.session.add(model_table) - self.session.flush() - - # Add or update the model version - existing_version = ( - self.session.query(ModelVersionTable) - .filter(ModelVersionTable.id == model_version.id) - .first() - ) - if not existing_version: - version_table = ModelVersionTable( - id=model_version.id, - model_id=model_version.model.id, - version=model_version.version, - description=model_version.description, - created_at=model_version.created_at, - ) - self.session.add(version_table) - self.session.flush() - - # Get seed objects from the model - seed_objects = model_version.model.create_seed_objects(model_version) - - # Delete ALL existing seed data for this model (not just this version) - # This ensures we start fresh with the new version's data - # Order matters due to foreign key constraints - - # First delete baseline parameter values (they reference parameters) - self.session.query(BaselineParameterValueTable).filter( - BaselineParameterValueTable.model_id == model_version.model.id - ).delete() - - # Then delete baseline variables for this model - self.session.query(BaselineVariableTable).filter( - BaselineVariableTable.model_id == model_version.model.id - ).delete() - - # Finally delete all parameters for this model - self.session.query(ParameterTable).filter( - ParameterTable.model_id == model_version.model.id - ).delete() - - self.session.commit() - - # Add all parameters first - for parameter in seed_objects.parameters: - # We need to add directly to session to avoid the autoflush issue - from .parameter_table import ParameterTable - - param_table = ParameterTable( - id=parameter.id, - model_id=parameter.model.id, # Now required as part of composite key - description=parameter.description, - data_type=parameter.data_type.__name__ - if parameter.data_type - else None, - label=parameter.label, - unit=parameter.unit, - ) - self.session.add(param_table) - - # Flush parameters to database so they exist for foreign key constraints - self.session.flush() - - # Add all baseline parameter values - for baseline_param_value in seed_objects.baseline_parameter_values: - import math - from uuid import uuid4 - - from .baseline_parameter_value_table import ( - BaselineParameterValueTable, - ) - - # Handle special float values that JSON doesn't support - value = baseline_param_value.value - if isinstance(value, float): - if math.isinf(value): - value = "Infinity" if value > 0 else "-Infinity" - elif math.isnan(value): - value = "NaN" - - bpv_table = BaselineParameterValueTable( - id=str(uuid4()), - parameter_id=baseline_param_value.parameter.id, - model_id=baseline_param_value.parameter.model.id, # Add model_id - 
model_version_id=baseline_param_value.model_version.id, - value=value, - start_date=baseline_param_value.start_date, - end_date=baseline_param_value.end_date, - ) - self.session.add(bpv_table) - - # Add all baseline variables - for baseline_variable in seed_objects.baseline_variables: - from .baseline_variable_table import BaselineVariableTable - - bv_table = BaselineVariableTable( - id=baseline_variable.id, - model_id=baseline_variable.model_version.model.id, # Add model_id - model_version_id=baseline_variable.model_version.id, - entity=baseline_variable.entity, - label=baseline_variable.label, - description=baseline_variable.description, - data_type=(lambda bv: compress_data(bv.data_type))( - baseline_variable - ) - if baseline_variable.data_type - else None, - ) - self.session.add(bv_table) - - # Commit everything at once - self.session.commit() diff --git a/build/lib/policyengine/database/dataset_table.py b/build/lib/policyengine/database/dataset_table.py deleted file mode 100644 index cf22cda8..00000000 --- a/build/lib/policyengine/database/dataset_table.py +++ /dev/null @@ -1,94 +0,0 @@ -from typing import TYPE_CHECKING -from uuid import uuid4 - -from sqlmodel import Field, SQLModel - -from policyengine.models import Dataset, Model, VersionedDataset -from policyengine.utils.compress import compress_data, decompress_data - -from .link import TableLink - -if TYPE_CHECKING: - from .database import Database - - -class DatasetTable(SQLModel, table=True): - __tablename__ = "datasets" - - id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True) - name: str = Field(nullable=False) - description: str | None = Field(default=None) - version: str | None = Field(default=None) - versioned_dataset_id: str | None = Field( - default=None, foreign_key="versioned_datasets.id", ondelete="SET NULL" - ) - year: int | None = Field(default=None) - data: bytes | None = Field(default=None) - model_id: str | None = Field( - default=None, foreign_key="models.id", ondelete="SET NULL" - ) - - @classmethod - def convert_from_model(cls, model: Dataset, database: "Database" = None) -> "DatasetTable": - """Convert a Dataset instance to a DatasetTable instance. - - Args: - model: The Dataset instance to convert - database: The database instance for persisting foreign objects if needed - - Returns: - A DatasetTable instance - """ - # Ensure foreign objects are persisted if database is provided - if database: - if model.versioned_dataset: - database.set(model.versioned_dataset, commit=False) - if model.model: - database.set(model.model, commit=False) - - return cls( - id=model.id, - name=model.name, - description=model.description, - version=model.version, - versioned_dataset_id=model.versioned_dataset.id if model.versioned_dataset else None, - year=model.year, - data=compress_data(model.data) if model.data else None, - model_id=model.model.id if model.model else None, - ) - - def convert_to_model(self, database: "Database" = None) -> Dataset: - """Convert this DatasetTable instance to a Dataset instance. 
- - Args: - database: The database instance for resolving foreign keys - - Returns: - A Dataset instance - """ - # Resolve foreign keys - versioned_dataset = None - model = None - - if database: - if self.versioned_dataset_id: - versioned_dataset = database.get(VersionedDataset, id=self.versioned_dataset_id) - if self.model_id: - model = database.get(Model, id=self.model_id) - - return Dataset( - id=self.id, - name=self.name, - description=self.description, - version=self.version, - versioned_dataset=versioned_dataset, - year=self.year, - data=decompress_data(self.data) if self.data else None, - model=model, - ) - - -dataset_table_link = TableLink( - model_cls=Dataset, - table_cls=DatasetTable, -) diff --git a/build/lib/policyengine/database/dynamic_table.py b/build/lib/policyengine/database/dynamic_table.py deleted file mode 100644 index 086e6bd9..00000000 --- a/build/lib/policyengine/database/dynamic_table.py +++ /dev/null @@ -1,68 +0,0 @@ -from datetime import datetime -from typing import TYPE_CHECKING -from uuid import uuid4 - -from sqlmodel import Field, SQLModel - -from policyengine.models import Dynamic -from policyengine.utils.compress import compress_data, decompress_data - -from .link import TableLink - -if TYPE_CHECKING: - from .database import Database - - -class DynamicTable(SQLModel, table=True): - __tablename__ = "dynamics" - - id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True) - name: str = Field(nullable=False) - description: str | None = Field(default=None) - simulation_modifier: bytes | None = Field(default=None) - created_at: datetime = Field(default_factory=datetime.now) - updated_at: datetime = Field(default_factory=datetime.now) - - @classmethod - def convert_from_model(cls, model: Dynamic, database: "Database" = None) -> "DynamicTable": - """Convert a Dynamic instance to a DynamicTable instance. - - Args: - model: The Dynamic instance to convert - database: The database instance (not used for this table) - - Returns: - A DynamicTable instance - """ - return cls( - id=model.id, - name=model.name, - description=model.description, - simulation_modifier=compress_data(model.simulation_modifier) if model.simulation_modifier else None, - created_at=model.created_at, - updated_at=model.updated_at, - ) - - def convert_to_model(self, database: "Database" = None) -> Dynamic: - """Convert this DynamicTable instance to a Dynamic instance. 
- - Args: - database: The database instance (not used for this table) - - Returns: - A Dynamic instance - """ - return Dynamic( - id=self.id, - name=self.name, - description=self.description, - simulation_modifier=decompress_data(self.simulation_modifier) if self.simulation_modifier else None, - created_at=self.created_at, - updated_at=self.updated_at, - ) - - -dynamic_table_link = TableLink( - model_cls=Dynamic, - table_cls=DynamicTable, -) diff --git a/build/lib/policyengine/database/link.py b/build/lib/policyengine/database/link.py deleted file mode 100644 index 2bb1a041..00000000 --- a/build/lib/policyengine/database/link.py +++ /dev/null @@ -1,8 +0,0 @@ -from pydantic import BaseModel -from sqlmodel import SQLModel - - -class TableLink(BaseModel): - """Simple registry mapping model classes to table classes.""" - model_cls: type[BaseModel] - table_cls: type[SQLModel] diff --git a/build/lib/policyengine/database/model_table.py b/build/lib/policyengine/database/model_table.py deleted file mode 100644 index 220238c8..00000000 --- a/build/lib/policyengine/database/model_table.py +++ /dev/null @@ -1,60 +0,0 @@ -from typing import TYPE_CHECKING - -from sqlmodel import Field, SQLModel - -from policyengine.models import Model -from policyengine.utils.compress import compress_data, decompress_data - -from .link import TableLink - -if TYPE_CHECKING: - from .database import Database - - -class ModelTable(SQLModel, table=True, extend_existing=True): - __tablename__ = "models" - - id: str = Field(primary_key=True) - name: str = Field(nullable=False) - description: str | None = Field(default=None) - simulation_function: bytes - - @classmethod - def convert_from_model(cls, model: Model, database: "Database" = None) -> "ModelTable": - """Convert a Model instance to a ModelTable instance. - - Args: - model: The Model instance to convert - database: The database instance (not used for this table) - - Returns: - A ModelTable instance - """ - return cls( - id=model.id, - name=model.name, - description=model.description, - simulation_function=compress_data(model.simulation_function), - ) - - def convert_to_model(self, database: "Database" = None) -> Model: - """Convert this ModelTable instance to a Model instance. 
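ModelTable above round-trips the model's simulation_function through compress_data/decompress_data from policyengine.utils.compress. That module is deleted by this patch and its body is not shown here, so the following stand-in is an assumption about its behaviour (serialise a Python object to compressed bytes and back):

    import pickle
    import zlib

    # Hypothetical stand-in for policyengine.utils.compress; the real
    # implementation is not shown in this patch.
    def compress_data(obj) -> bytes:
        return zlib.compress(pickle.dumps(obj))

    def decompress_data(blob: bytes):
        return pickle.loads(zlib.decompress(blob))

Since simulation_function is a callable, the real implementation presumably uses something like cloudpickle rather than plain pickle, which cannot serialise arbitrary functions.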
- - Args: - database: The database instance (not used for this table) - - Returns: - A Model instance - """ - return Model( - id=self.id, - name=self.name, - description=self.description, - simulation_function=decompress_data(self.simulation_function), - ) - - -model_table_link = TableLink( - model_cls=Model, - table_cls=ModelTable, -) diff --git a/build/lib/policyengine/database/model_version_table.py b/build/lib/policyengine/database/model_version_table.py deleted file mode 100644 index 86d19fed..00000000 --- a/build/lib/policyengine/database/model_version_table.py +++ /dev/null @@ -1,73 +0,0 @@ -from datetime import datetime -from typing import TYPE_CHECKING -from uuid import uuid4 - -from sqlmodel import Field, SQLModel - -from policyengine.models import Model, ModelVersion - -from .link import TableLink - -if TYPE_CHECKING: - from .database import Database - - -class ModelVersionTable(SQLModel, table=True): - __tablename__ = "model_versions" - - id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True) - model_id: str = Field(foreign_key="models.id", ondelete="CASCADE") - version: str = Field(nullable=False) - description: str | None = Field(default=None) - created_at: datetime = Field(default_factory=datetime.now) - - @classmethod - def convert_from_model(cls, model: ModelVersion, database: "Database" = None) -> "ModelVersionTable": - """Convert a ModelVersion instance to a ModelVersionTable instance. - - Args: - model: The ModelVersion instance to convert - database: The database instance for persisting the model if needed - - Returns: - A ModelVersionTable instance - """ - # Ensure the Model is persisted if database is provided - if database and model.model: - database.set(model.model, commit=False) - - return cls( - id=model.id, - model_id=model.model.id if model.model else None, - version=model.version, - description=model.description, - created_at=model.created_at, - ) - - def convert_to_model(self, database: "Database" = None) -> ModelVersion: - """Convert this ModelVersionTable instance to a ModelVersion instance. 
- - Args: - database: The database instance for resolving the model foreign key - - Returns: - A ModelVersion instance - """ - # Resolve the model foreign key - model = None - if database and self.model_id: - model = database.get(Model, id=self.model_id) - - return ModelVersion( - id=self.id, - model=model, - version=self.version, - description=self.description, - created_at=self.created_at, - ) - - -model_version_table_link = TableLink( - model_cls=ModelVersion, - table_cls=ModelVersionTable, -) diff --git a/build/lib/policyengine/database/parameter_table.py b/build/lib/policyengine/database/parameter_table.py deleted file mode 100644 index aef88e5a..00000000 --- a/build/lib/policyengine/database/parameter_table.py +++ /dev/null @@ -1,92 +0,0 @@ -from typing import TYPE_CHECKING - -from sqlmodel import Field, SQLModel - -from policyengine.models import Model, Parameter - -from .link import TableLink - -if TYPE_CHECKING: - from .database import Database - - -class ParameterTable(SQLModel, table=True): - __tablename__ = "parameters" - __table_args__ = ({"extend_existing": True},) - - id: str = Field(primary_key=True) # Parameter name - model_id: str = Field( - primary_key=True, foreign_key="models.id" - ) # Part of composite key - description: str | None = Field(default=None) - data_type: str | None = Field(nullable=True) # Data type name - label: str | None = Field(default=None) - unit: str | None = Field(default=None) - - @classmethod - def convert_from_model(cls, model: Parameter, database: "Database" = None) -> "ParameterTable": - """Convert a Parameter instance to a ParameterTable instance. - - Args: - model: The Parameter instance to convert - database: The database instance for persisting the model if needed - - Returns: - A ParameterTable instance - """ - # Ensure the Model is persisted if database is provided - if database and model.model: - database.set(model.model, commit=False) - - return cls( - id=model.id, - model_id=model.model.id if model.model else None, - description=model.description, - data_type=model.data_type.__name__ if model.data_type else None, - label=model.label, - unit=model.unit, - ) - - def convert_to_model(self, database: "Database" = None) -> Parameter: - """Convert this ParameterTable instance to a Parameter instance. 
- - Args: - database: The database instance for resolving the model foreign key - - Returns: - A Parameter instance - """ - from .model_table import ModelTable - from sqlmodel import select - - # Resolve the model foreign key - model = None - if database and self.model_id: - model_table = database.session.exec( - select(ModelTable).where(ModelTable.id == self.model_id) - ).first() - if model_table: - model = model_table.convert_to_model(database) - - # Convert data_type string back to type - data_type = None - if self.data_type: - try: - data_type = eval(self.data_type) - except: - data_type = None - - return Parameter( - id=self.id, - description=self.description, - data_type=data_type, - model=model, - label=self.label, - unit=self.unit, - ) - - -parameter_table_link = TableLink( - model_cls=Parameter, - table_cls=ParameterTable, -) diff --git a/build/lib/policyengine/database/parameter_value_table.py b/build/lib/policyengine/database/parameter_value_table.py deleted file mode 100644 index 7bd02d0a..00000000 --- a/build/lib/policyengine/database/parameter_value_table.py +++ /dev/null @@ -1,108 +0,0 @@ -from datetime import datetime -from typing import TYPE_CHECKING, Any -from uuid import uuid4 - -from sqlmodel import JSON, Column, Field, SQLModel - -from policyengine.models import Parameter, ParameterValue - -from .link import TableLink - -if TYPE_CHECKING: - from .database import Database - - -class ParameterValueTable(SQLModel, table=True): - __tablename__ = "parameter_values" - __table_args__ = ({"extend_existing": True},) - - id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True) - parameter_id: str = Field(nullable=False) # Part of composite foreign key - model_id: str = Field(nullable=False) # Part of composite foreign key - policy_id: str | None = Field(default=None, foreign_key="policies.id", ondelete="CASCADE") # Link to policy - value: Any | None = Field( - default=None, sa_column=Column(JSON) - ) # JSON field for any type - start_date: datetime = Field(nullable=False) - end_date: datetime | None = Field(default=None) - - @classmethod - def convert_from_model(cls, model: ParameterValue, database: "Database" = None) -> "ParameterValueTable": - """Convert a ParameterValue instance to a ParameterValueTable instance. - - Args: - model: The ParameterValue instance to convert - database: The database instance for persisting the parameter if needed - - Returns: - A ParameterValueTable instance - """ - import math - - # Ensure the Parameter is persisted if database is provided - if database and model.parameter: - database.set(model.parameter, commit=False) - - # Handle special float values - value = model.value - if isinstance(value, float): - if math.isinf(value): - value = "Infinity" if value > 0 else "-Infinity" - elif math.isnan(value): - value = "NaN" - - return cls( - id=model.id, - parameter_id=model.parameter.id if model.parameter else None, - model_id=model.parameter.model.id if model.parameter and model.parameter.model else None, - value=value, - start_date=model.start_date, - end_date=model.end_date, - ) - - def convert_to_model(self, database: "Database" = None) -> ParameterValue: - """Convert this ParameterValueTable instance to a ParameterValue instance. 
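Both parameter-value tables in this patch apply the same trick for JSON storage: JSON has no representation for inf, -inf or nan, so those floats are swapped for sentinel strings on the way in and converted back on the way out. Distilled from the conversion methods shown:

    import math

    def encode_json_float(value):
        # Mirrors convert_from_model above: special floats become strings.
        if isinstance(value, float):
            if math.isinf(value):
                return "Infinity" if value > 0 else "-Infinity"
            if math.isnan(value):
                return "NaN"
        return value

    def decode_json_float(value):
        # Mirrors convert_to_model above: sentinel strings become floats.
        if value == "Infinity":
            return float("inf")
        if value == "-Infinity":
            return float("-inf")
        if value == "NaN":
            return float("nan")
        return value

    assert decode_json_float(encode_json_float(float("inf"))) == float("inf")

One consequence, visible in the code as written: a genuine string value equal to "Infinity" or "NaN" would be decoded as a float. These tables accept that trade-off since parameter values are numeric.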
-
-        Args:
-            database: The database instance for resolving the parameter foreign key
-
-        Returns:
-            A ParameterValue instance
-        """
-        from .parameter_table import ParameterTable
-        from sqlmodel import select
-
-        # Resolve the parameter foreign key
-        parameter = None
-        if database and self.parameter_id and self.model_id:
-            param_table = database.session.exec(
-                select(ParameterTable).where(
-                    ParameterTable.id == self.parameter_id,
-                    ParameterTable.model_id == self.model_id
-                )
-            ).first()
-            if param_table:
-                parameter = param_table.convert_to_model(database)
-
-        # Handle special string values
-        value = self.value
-        if value == "Infinity":
-            value = float("inf")
-        elif value == "-Infinity":
-            value = float("-inf")
-        elif value == "NaN":
-            value = float("nan")
-
-        return ParameterValue(
-            id=self.id,
-            parameter=parameter,
-            value=value,
-            start_date=self.start_date,
-            end_date=self.end_date,
-        )
-
-
-parameter_value_table_link = TableLink(
-    model_cls=ParameterValue,
-    table_cls=ParameterValueTable,
-)
diff --git a/build/lib/policyengine/database/policy_table.py b/build/lib/policyengine/database/policy_table.py
deleted file mode 100644
index 0ae381e4..00000000
--- a/build/lib/policyengine/database/policy_table.py
+++ /dev/null
@@ -1,136 +0,0 @@
-from datetime import datetime
-from typing import TYPE_CHECKING
-from uuid import uuid4
-
-from sqlmodel import Field, SQLModel
-
-from policyengine.models import Policy
-from policyengine.utils.compress import compress_data, decompress_data
-
-from .link import TableLink
-
-if TYPE_CHECKING:
-    from .database import Database
-
-
-class PolicyTable(SQLModel, table=True):
-    __tablename__ = "policies"
-
-    id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True)
-    name: str = Field(nullable=False)
-    description: str | None = Field(default=None)
-    simulation_modifier: bytes | None = Field(default=None)
-    created_at: datetime = Field(default_factory=datetime.now)
-    updated_at: datetime = Field(default_factory=datetime.now)
-
-    @classmethod
-    def convert_from_model(cls, model: Policy, database: "Database" = None) -> "PolicyTable":
-        """Convert a Policy instance to a PolicyTable instance.
-
-        Args:
-            model: The Policy instance to convert
-            database: The database instance for persisting nested objects
-
-        Returns:
-            A PolicyTable instance
-        """
-        policy_table = cls(
-            id=model.id,
-            name=model.name,
-            description=model.description,
-            simulation_modifier=compress_data(model.simulation_modifier) if model.simulation_modifier else None,
-            created_at=model.created_at,
-            updated_at=model.updated_at,
-        )
-
-        # Handle nested parameter values if database is provided
-        if database and model.parameter_values:
-            from .parameter_value_table import ParameterValueTable
-            from sqlmodel import select
-
-            # First ensure the policy table is saved to the database
-            # This is necessary so the foreign key constraint is satisfied
-            # Check if it already exists
-            existing_policy = database.session.exec(
-                select(PolicyTable).where(PolicyTable.id == model.id)
-            ).first()
-
-            if not existing_policy:
-                database.session.add(policy_table)
-                database.session.flush()
-
-            # Track which parameter value IDs we want to keep
-            desired_pv_ids = {pv.id for pv in model.parameter_values}
-
-            # Delete only parameter values linked to this policy that are NOT in the new list
-            existing_pvs = database.session.exec(
-                select(ParameterValueTable).where(ParameterValueTable.policy_id == model.id)
-            ).all()
-            for pv in existing_pvs:
-                if pv.id not in desired_pv_ids:
-                    database.session.delete(pv)
-
-            # Now save/update the parameter values
-            for param_value in model.parameter_values:
-                # Check if this parameter value already exists in the database
-                existing_pv = database.session.exec(
-                    select(ParameterValueTable).where(ParameterValueTable.id == param_value.id)
-                ).first()
-
-                if existing_pv:
-                    # Update existing parameter value
-                    pv_table = ParameterValueTable.convert_from_model(param_value, database)
-                    existing_pv.parameter_id = pv_table.parameter_id
-                    existing_pv.model_id = pv_table.model_id
-                    existing_pv.policy_id = model.id
-                    existing_pv.value = pv_table.value
-                    existing_pv.start_date = pv_table.start_date
-                    existing_pv.end_date = pv_table.end_date
-                else:
-                    # Create new parameter value
-                    pv_table = ParameterValueTable.convert_from_model(param_value, database)
-                    pv_table.policy_id = model.id  # Link to this policy
-                    database.session.add(pv_table)
-                    database.session.flush()
-
-        return policy_table
-
-    def convert_to_model(self, database: "Database" = None) -> Policy:
-        """Convert this PolicyTable instance to a Policy instance.
-
-        Args:
-            database: The database instance for loading nested objects
-
-        Returns:
-            A Policy instance
-        """
-        # Load nested parameter values if database is provided
-        parameter_values = []
-        if database:
-            from .parameter_value_table import ParameterValueTable
-            from sqlmodel import select
-
-            # Query for all parameter values linked to this policy
-            pv_tables = database.session.exec(
-                select(ParameterValueTable).where(ParameterValueTable.policy_id == self.id)
-            ).all()
-
-            # Convert each one to a model
-            for pv_table in pv_tables:
-                parameter_values.append(pv_table.convert_to_model(database))
-
-        return Policy(
-            id=self.id,
-            name=self.name,
-            description=self.description,
-            parameter_values=parameter_values,
-            simulation_modifier=decompress_data(self.simulation_modifier) if self.simulation_modifier else None,
-            created_at=self.created_at,
-            updated_at=self.updated_at,
-        )
-
-
-policy_table_link = TableLink(
-    model_cls=Policy,
-    table_cls=PolicyTable,
-)
diff --git a/build/lib/policyengine/database/report_element_table.py b/build/lib/policyengine/database/report_element_table.py
deleted file mode 100644
index cc69e83e..00000000
--- a/build/lib/policyengine/database/report_element_table.py
+++ /dev/null
@@ -1,106 +0,0 @@
-import uuid
-from datetime import datetime
-
-from sqlmodel import Field, SQLModel, Column, JSON
-from typing import TYPE_CHECKING
-
-from policyengine.models.report_element import ReportElement
-
-from .link import TableLink
-
-if TYPE_CHECKING:
-    from .database import Database
-
-
-class ReportElementTable(SQLModel, table=True, extend_existing=True):
-    __tablename__ = "report_elements"
-
-    id: str = Field(
-        primary_key=True, default_factory=lambda: str(uuid.uuid4())
-    )
-    label: str = Field(nullable=False)
-    type: str = Field(nullable=False)  # "chart" or "markdown"
-
-    # Data source
-    data_table: str | None = Field(default=None)  # "aggregates" or "aggregate_changes"
-
-    # Chart configuration
-    chart_type: str | None = Field(
-        default=None
-    )  # "bar", "line", "scatter", "area", "pie"
-    x_axis_variable: str | None = Field(default=None)
-    y_axis_variable: str | None = Field(default=None)
-    group_by: str | None = Field(default=None)
-    color_by: str | None = Field(default=None)
-    size_by: str | None = Field(default=None)
-
-    # Markdown specific
-    markdown_content: str | None = Field(default=None)
-
-    # Metadata
-    report_id: str | None = Field(default=None, foreign_key="reports.id")
-    user_id: str | None = Field(default=None, foreign_key="users.id")
-    model_version_id: str | None = Field(default=None, foreign_key="model_versions.id")
-    position: int | None = Field(default=None)
-    visible: bool | None = Field(default=True)
-    custom_config: dict | None = Field(default=None, sa_column=Column(JSON))
-    report_element_metadata: dict | None = Field(default=None, sa_column=Column(JSON))
-    created_at: datetime = Field(default_factory=datetime.utcnow)
-    updated_at: datetime = Field(default_factory=datetime.utcnow)
-
-    @classmethod
-    def convert_from_model(cls, model: ReportElement, database: "Database" = None) -> "ReportElementTable":
-        """Convert a ReportElement instance to a ReportElementTable instance."""
-        return cls(
-            id=model.id,
-            label=model.label,
-            type=model.type,
-            data_table=model.data_table,
-            chart_type=model.chart_type,
-            x_axis_variable=model.x_axis_variable,
-            y_axis_variable=model.y_axis_variable,
-            group_by=model.group_by,
-            color_by=model.color_by,
-            size_by=model.size_by,
-            markdown_content=model.markdown_content,
-            report_id=model.report_id,
-            user_id=model.user_id,
-            model_version_id=model.model_version_id,
-            position=model.position,
-            visible=model.visible,
-            custom_config=model.custom_config,
-            report_element_metadata=model.report_element_metadata,
-            created_at=model.created_at,
-            updated_at=model.updated_at,
-        )
-
-    def convert_to_model(self, database: "Database" = None) -> ReportElement:
-        """Convert this ReportElementTable instance to a ReportElement instance."""
-        return ReportElement(
-            id=self.id,
-            label=self.label,
-            type=self.type,
-            data_table=self.data_table,
-            chart_type=self.chart_type,
-            x_axis_variable=self.x_axis_variable,
-            y_axis_variable=self.y_axis_variable,
-            group_by=self.group_by,
-            color_by=self.color_by,
-            size_by=self.size_by,
-            markdown_content=self.markdown_content,
-            report_id=self.report_id,
-            user_id=self.user_id,
-            model_version_id=self.model_version_id,
-            position=self.position,
-            visible=self.visible,
-            custom_config=self.custom_config,
-            report_element_metadata=self.report_element_metadata,
-            created_at=self.created_at,
-            updated_at=self.updated_at,
-        )
-
-
-report_element_table_link = TableLink(
-    model_cls=ReportElement,
-    table_cls=ReportElementTable,
-)
diff --git a/build/lib/policyengine/database/report_table.py b/build/lib/policyengine/database/report_table.py
deleted file mode 100644
index 79c11cf0..00000000
--- a/build/lib/policyengine/database/report_table.py
+++ /dev/null
@@ -1,120 +0,0 @@
-import uuid
-from datetime import datetime
-
-from sqlmodel import Field, SQLModel
-from typing import TYPE_CHECKING
-
-from policyengine.models.report import Report
-
-from .link import TableLink
-
-if TYPE_CHECKING:
-    from .database import Database
-
-
-class ReportTable(SQLModel, table=True, extend_existing=True):
-    __tablename__ = "reports"
-
-    id: str = Field(
-        primary_key=True, default_factory=lambda: str(uuid.uuid4())
-    )
-    label: str = Field(nullable=False)
-    created_at: datetime = Field(default_factory=datetime.utcnow)
-
-    @classmethod
-    def convert_from_model(cls, model: Report, database: "Database" = None) -> "ReportTable":
-        """Convert a Report instance to a ReportTable instance."""
-        report_table = cls(
-            id=model.id,
-            label=model.label,
-            created_at=model.created_at,
-        )
-
-        # Handle nested report elements if database is provided
-        if database and model.elements:
-            from .report_element_table import ReportElementTable
-            from sqlmodel import select
-
-            # First ensure the report table is saved to the database
-            # This is necessary so the foreign key constraint is satisfied
-            # Check if it already exists
-            existing_report = database.session.exec(
-                select(ReportTable).where(ReportTable.id == model.id)
-            ).first()
-
-            if not existing_report:
-                database.session.add(report_table)
-                database.session.flush()
-
-            # Track which element IDs we want to keep
-            desired_elem_ids = {elem.id for elem in model.elements}
-
-            # Delete only elements linked to this report that are NOT in the new list
-            existing_elems = database.session.exec(
-                select(ReportElementTable).where(ReportElementTable.report_id == model.id)
-            ).all()
-            for elem in existing_elems:
-                if elem.id not in desired_elem_ids:
-                    database.session.delete(elem)
-
-            # Now save/update the elements
-            for i, element in enumerate(model.elements):
-                # Check if this element already exists in the database
-                existing_elem = database.session.exec(
-                    select(ReportElementTable).where(ReportElementTable.id == element.id)
-                ).first()
-
-                if existing_elem:
-                    # Update existing element
-                    elem_table = ReportElementTable.convert_from_model(element, database)
-                    existing_elem.report_id = model.id
-                    existing_elem.position = i
-                    existing_elem.label = elem_table.label
-                    existing_elem.type = elem_table.type
-                    existing_elem.markdown_content = elem_table.markdown_content
-                    existing_elem.chart_type = elem_table.chart_type
-                    existing_elem.x_axis_variable = elem_table.x_axis_variable
-                    existing_elem.y_axis_variable = elem_table.y_axis_variable
-                    existing_elem.baseline_simulation_id = elem_table.baseline_simulation_id
-                    existing_elem.reform_simulation_id = elem_table.reform_simulation_id
-                else:
-                    # Create new element
-                    elem_table = ReportElementTable.convert_from_model(element, database)
-                    elem_table.report_id = model.id  # Link to this report
-                    elem_table.position = i  # Maintain order
-                    database.session.add(elem_table)
-                    database.session.flush()
-
-        return report_table
-
-    def convert_to_model(self, database: "Database" = None) -> Report:
-        """Convert this ReportTable instance to a Report instance."""
-        # Load nested report elements if database is provided
-        elements = []
-        if database:
-            from .report_element_table import ReportElementTable
-            from sqlmodel import select
-
-            # Query for all elements linked to this report, ordered by position
-            elem_tables = database.session.exec(
-                select(ReportElementTable)
-                .where(ReportElementTable.report_id == self.id)
-                .order_by(ReportElementTable.position)
-            ).all()
-
-            # Convert each one to a model
-            for elem_table in elem_tables:
-                elements.append(elem_table.convert_to_model(database))
-
-        return Report(
-            id=self.id,
-            label=self.label,
-            created_at=self.created_at,
-            elements=elements,
-        )
-
-
-report_table_link = TableLink(
-    model_cls=Report,
-    table_cls=ReportTable,
-)
diff --git a/build/lib/policyengine/database/simulation_table.py b/build/lib/policyengine/database/simulation_table.py
deleted file mode 100644
index de45a419..00000000
--- a/build/lib/policyengine/database/simulation_table.py
+++ /dev/null
@@ -1,225 +0,0 @@
-from datetime import datetime
-from typing import TYPE_CHECKING
-from uuid import uuid4
-
-from sqlmodel import Field, SQLModel
-
-from policyengine.models import Dataset, Dynamic, Model, ModelVersion, Policy, Simulation
-from policyengine.utils.compress import compress_data, decompress_data
-
-from .link import TableLink
-
-if TYPE_CHECKING:
-    from .database import Database
-
-
-class SimulationTable(SQLModel, table=True):
-    __tablename__ = "simulations"
-
-    id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True)
-    created_at: datetime = Field(default_factory=datetime.now)
-    updated_at: datetime = Field(default_factory=datetime.now)
-
-    policy_id: str | None = Field(
-        default=None, foreign_key="policies.id", ondelete="SET NULL"
-    )
-    dynamic_id: str | None = Field(
-        default=None, foreign_key="dynamics.id", ondelete="SET NULL"
-    )
-    dataset_id: str = Field(foreign_key="datasets.id", ondelete="CASCADE")
-    model_id: str = Field(foreign_key="models.id", ondelete="CASCADE")
-    model_version_id: str | None = Field(
-        default=None, foreign_key="model_versions.id", ondelete="SET NULL"
-    )
-
-    result: bytes | None = Field(default=None)
-
-    @classmethod
-    def convert_from_model(cls, model: Simulation, database: "Database" = None) -> "SimulationTable":
-        """Convert a Simulation instance to a SimulationTable instance.
-
-        Args:
-            model: The Simulation instance to convert
-            database: The database instance for persisting foreign objects if needed
-
-        Returns:
-            A SimulationTable instance
-        """
-        # Ensure all foreign objects are persisted if database is provided
-        if database:
-            if model.policy:
-                database.set(model.policy, commit=False)
-            if model.dynamic:
-                database.set(model.dynamic, commit=False)
-            if model.dataset:
-                database.set(model.dataset, commit=False)
-            if model.model:
-                database.set(model.model, commit=False)
-            if model.model_version:
-                database.set(model.model_version, commit=False)
-
-        sim_table = cls(
-            id=model.id,
-            created_at=model.created_at,
-            updated_at=model.updated_at,
-            policy_id=model.policy.id if model.policy else None,
-            dynamic_id=model.dynamic.id if model.dynamic else None,
-            dataset_id=model.dataset.id if model.dataset else None,
-            model_id=model.model.id if model.model else None,
-            model_version_id=model.model_version.id if model.model_version else None,
-            result=compress_data(model.result) if model.result else None,
-        )
-
-        # Handle nested aggregates if database is provided
-        if database and model.aggregates:
-            from .aggregate import AggregateTable
-            from sqlmodel import select
-
-            # First ensure the simulation table is saved to the database
-            # This is necessary so the foreign key constraint is satisfied
-            # Check if it already exists
-            existing_sim = database.session.exec(
-                select(SimulationTable).where(SimulationTable.id == model.id)
-            ).first()
-
-            if not existing_sim:
-                database.session.add(sim_table)
-                database.session.flush()
-
-            # Track which aggregate IDs we want to keep
-            desired_agg_ids = {agg.id for agg in model.aggregates}
-
-            # Delete only aggregates linked to this simulation that are NOT in the new list
-            existing_aggs = database.session.exec(
-                select(AggregateTable).where(AggregateTable.simulation_id == model.id)
-            ).all()
-            for agg in existing_aggs:
-                if agg.id not in desired_agg_ids:
-                    database.session.delete(agg)
-
-            # Now save/update the aggregates
-            for aggregate in model.aggregates:
-                # Check if this aggregate already exists in the database
-                existing_agg = database.session.exec(
-                    select(AggregateTable).where(AggregateTable.id == aggregate.id)
-                ).first()
-
-                if existing_agg:
-                    # Update existing aggregate
-                    agg_table = AggregateTable.convert_from_model(aggregate, database)
-                    existing_agg.simulation_id = agg_table.simulation_id
-                    existing_agg.entity = agg_table.entity
-                    existing_agg.variable_name = agg_table.variable_name
-                    existing_agg.year = agg_table.year
-                    existing_agg.filter_variable_name = agg_table.filter_variable_name
-                    existing_agg.filter_variable_value = agg_table.filter_variable_value
-                    existing_agg.filter_variable_leq = agg_table.filter_variable_leq
-                    existing_agg.filter_variable_geq = agg_table.filter_variable_geq
-                    existing_agg.aggregate_function = agg_table.aggregate_function
-                    existing_agg.value = agg_table.value
-                else:
-                    # Create new aggregate
-                    agg_table = AggregateTable.convert_from_model(aggregate, database)
-                    database.session.add(agg_table)
-                    database.session.flush()
-
-        return sim_table
-
-    def convert_to_model(self, database: "Database" = None) -> Simulation:
-        """Convert this SimulationTable instance to a Simulation instance.
-
-        Args:
-            database: The database instance for resolving foreign keys
-
-        Returns:
-            A Simulation instance
-        """
-        from sqlmodel import select
-
-        from .model_version_table import ModelVersionTable
-        from .policy_table import PolicyTable
-        from .dataset_table import DatasetTable
-        from .model_table import ModelTable
-        from .dynamic_table import DynamicTable
-
-        # Resolve all foreign keys
-        policy = None
-        dynamic = None
-        dataset = None
-        model = None
-        model_version = None
-
-        if database:
-            if self.policy_id:
-                policy_table = database.session.exec(
-                    select(PolicyTable).where(PolicyTable.id == self.policy_id)
-                ).first()
-                if policy_table:
-                    policy = policy_table.convert_to_model(database)
-
-            if self.dynamic_id:
-                try:
-                    dynamic_table = database.session.exec(
-                        select(DynamicTable).where(DynamicTable.id == self.dynamic_id)
-                    ).first()
-                    if dynamic_table:
-                        dynamic = dynamic_table.convert_to_model(database)
-                except:
-                    # Dynamic table might not be defined yet
-                    dynamic = database.get(Dynamic, id=self.dynamic_id)
-
-            if self.dataset_id:
-                dataset_table = database.session.exec(
-                    select(DatasetTable).where(DatasetTable.id == self.dataset_id)
-                ).first()
-                if dataset_table:
-                    dataset = dataset_table.convert_to_model(database)
-
-            if self.model_id:
-                model_table = database.session.exec(
-                    select(ModelTable).where(ModelTable.id == self.model_id)
-                ).first()
-                if model_table:
-                    model = model_table.convert_to_model(database)
-
-            if self.model_version_id:
-                version_table = database.session.exec(
-                    select(ModelVersionTable).where(ModelVersionTable.id == self.model_version_id)
-                ).first()
-                if version_table:
-                    model_version = version_table.convert_to_model(database)
-
-        # Load aggregates
-        aggregates = []
-        if database:
-            from .aggregate import AggregateTable
-            from sqlmodel import select
-
-            agg_tables = database.session.exec(
-                select(AggregateTable).where(AggregateTable.simulation_id == self.id)
-            ).all()
-
-            for agg_table in agg_tables:
-                # Don't pass database to avoid circular reference issues
-                # The simulation reference will be set separately
-                agg_model = agg_table.convert_to_model(None)
-                aggregates.append(agg_model)
-
-        return Simulation(
-            id=self.id,
-            created_at=self.created_at,
-            updated_at=self.updated_at,
-            policy=policy,
-            dynamic=dynamic,
-            dataset=dataset,
-            model=model,
-            model_version=model_version,
-            result=decompress_data(self.result) if self.result else None,
-            aggregates=aggregates,
-        )
-
-
-simulation_table_link = TableLink(
-    model_cls=Simulation,
-    table_cls=SimulationTable,
-)
diff --git a/build/lib/policyengine/database/table_mixin.py b/build/lib/policyengine/database/table_mixin.py
deleted file mode 100644
index a29cdeb6..00000000
--- a/build/lib/policyengine/database/table_mixin.py
+++ /dev/null
@@ -1,80 +0,0 @@
-from typing import TYPE_CHECKING, Any, ClassVar, TypeVar
-
-from pydantic import BaseModel
-from sqlmodel import SQLModel
-
-if TYPE_CHECKING:
-    from .database import Database
-
-T = TypeVar("T", bound=BaseModel)
-
-
-class TableConversionMixin:
-    """Mixin class for SQLModel tables to provide conversion methods between table instances and Pydantic models."""
-
-    _model_cls: ClassVar[type[BaseModel]] = None
-    _foreign_key_fields: ClassVar[dict[str, type[BaseModel]]] = {}
-
-    @classmethod
-    def convert_from_model(cls, model: BaseModel, database: "Database" = None) -> SQLModel:
-        """Convert a Pydantic model instance to a table instance, resolving foreign objects to IDs.
-
-        Args:
-            model: The Pydantic model instance to convert
-            database: The database instance for resolving foreign objects (optional)
-
-        Returns:
-            An instance of the SQLModel table class
-        """
-        data = {}
-
-        for field_name in cls.__annotations__.keys():
-            # Check if this field is a foreign key that needs resolution
-            if field_name in cls._foreign_key_fields:
-                # Extract ID from the nested object
-                nested_obj = getattr(model, field_name.replace("_id", ""), None)
-                if nested_obj:
-                    # If we need to ensure the foreign object exists in DB
-                    if database:
-                        database.set(nested_obj, commit=False)
-                    data[field_name] = nested_obj.id if hasattr(nested_obj, "id") else None
-                else:
-                    data[field_name] = None
-            elif hasattr(model, field_name):
-                # Direct field mapping
-                data[field_name] = getattr(model, field_name)
-
-        return cls(**data)
-
-    @classmethod
-    def convert_to_model(cls, table_instance: SQLModel, database: "Database" = None) -> BaseModel:
-        """Convert a table instance to a Pydantic model, resolving foreign key IDs to objects.
-
-        Args:
-            table_instance: The SQLModel table instance to convert
-            database: The database instance for resolving foreign keys (required if foreign keys exist)
-
-        Returns:
-            An instance of the Pydantic model class
-        """
-        if cls._model_cls is None:
-            raise ValueError(f"Model class not set for {cls.__name__}")
-
-        data = {}
-
-        for field_name in cls._model_cls.__annotations__.keys():
-            # Check if we need to resolve a foreign key
-            fk_field = f"{field_name}_id"
-            if fk_field in cls._foreign_key_fields and database:
-                # Resolve the foreign key to an object
-                fk_id = getattr(table_instance, fk_field, None)
-                if fk_id:
-                    foreign_model_cls = cls._foreign_key_fields[fk_field]
-                    data[field_name] = database.get(foreign_model_cls, id=fk_id)
-                else:
-                    data[field_name] = None
-            elif hasattr(table_instance, field_name):
-                # Direct field mapping
-                data[field_name] = getattr(table_instance, field_name)
-
-        return cls._model_cls(**data)
\ No newline at end of file
diff --git a/build/lib/policyengine/database/user_table.py b/build/lib/policyengine/database/user_table.py
deleted file mode 100644
index d663ac8f..00000000
--- a/build/lib/policyengine/database/user_table.py
+++ /dev/null
@@ -1,57 +0,0 @@
-import uuid
-from datetime import datetime
-
-from sqlmodel import Field, SQLModel
-from typing import TYPE_CHECKING
-
-from policyengine.models.user import User
-
-from .link import TableLink
-
-if TYPE_CHECKING:
-    from .database import Database
-
-
-class UserTable(SQLModel, table=True, extend_existing=True):
-    __tablename__ = "users"
-
-    id: str = Field(
-        primary_key=True, default_factory=lambda: str(uuid.uuid4())
-    )
-    username: str = Field(nullable=False, unique=True)
-    first_name: str | None = Field(default=None)
-    last_name: str | None = Field(default=None)
-    email: str | None = Field(default=None)
-    created_at: datetime = Field(default_factory=datetime.utcnow)
-    updated_at: datetime = Field(default_factory=datetime.utcnow)
-
-    @classmethod
-    def convert_from_model(cls, model: User, database: "Database" = None) -> "UserTable":
-        """Convert a User instance to a UserTable instance."""
-        return cls(
-            id=model.id,
-            username=model.username,
-            first_name=model.first_name,
-            last_name=model.last_name,
-            email=model.email,
-            created_at=model.created_at,
-            updated_at=model.updated_at,
-        )
-
-    def convert_to_model(self, database: "Database" = None) -> User:
-        """Convert this UserTable instance to a User instance."""
-        return User(
-            id=self.id,
-            username=self.username,
-            first_name=self.first_name,
-            last_name=self.last_name,
-            email=self.email,
-            created_at=self.created_at,
-            updated_at=self.updated_at,
-        )
-
-
-user_table_link = TableLink(
-    model_cls=User,
-    table_cls=UserTable,
-)
diff --git a/build/lib/policyengine/database/versioned_dataset_table.py b/build/lib/policyengine/database/versioned_dataset_table.py
deleted file mode 100644
index 4e1524c9..00000000
--- a/build/lib/policyengine/database/versioned_dataset_table.py
+++ /dev/null
@@ -1,45 +0,0 @@
-from uuid import uuid4
-
-from sqlmodel import Field, SQLModel
-from typing import TYPE_CHECKING
-
-from policyengine.models import VersionedDataset
-
-from .link import TableLink
-
-if TYPE_CHECKING:
-    from .database import Database
-
-
-class VersionedDatasetTable(SQLModel, table=True):
-    __tablename__ = "versioned_datasets"
-
-    id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True)
-    name: str = Field(nullable=False)
-    description: str = Field(nullable=False)
-    model_id: str | None = Field(
-        default=None, foreign_key="models.id", ondelete="SET NULL"
-    )
-
-    @classmethod
-    def convert_from_model(cls, model: VersionedDataset, database: "Database" = None) -> "VersionedDatasetTable":
-        """Convert a VersionedDataset instance to a VersionedDatasetTable instance."""
-        return cls(
-            id=model.id,
-            name=model.name,
-            description=model.description,
-        )
-
-    def convert_to_model(self, database: "Database" = None) -> VersionedDataset:
-        """Convert this VersionedDatasetTable instance to a VersionedDataset instance."""
-        return VersionedDataset(
-            id=self.id,
-            name=self.name,
-            description=self.description,
-        )
-
-
-versioned_dataset_table_link = TableLink(
-    model_cls=VersionedDataset,
-    table_cls=VersionedDatasetTable,
-)
diff --git a/build/lib/policyengine/models/__init__.py b/build/lib/policyengine/models/__init__.py
deleted file mode 100644
index de5fd8c9..00000000
--- a/build/lib/policyengine/models/__init__.py
+++ /dev/null
@@ -1,39 +0,0 @@
-from .aggregate import Aggregate as Aggregate
-from .aggregate import AggregateType as AggregateType
-from .aggregate_change import AggregateChange as AggregateChange
-from .baseline_parameter_value import (
-    BaselineParameterValue as BaselineParameterValue,
-)
-from .baseline_variable import BaselineVariable as BaselineVariable
-from .dataset import Dataset as Dataset
-from .dynamic import Dynamic as Dynamic
-from .model import Model as Model
-from .model_version import ModelVersion as ModelVersion
-from .parameter import Parameter as Parameter
-from .parameter_value import ParameterValue as ParameterValue
-from .policy import Policy as Policy
-from .policyengine_uk import (
-    policyengine_uk_latest_version as policyengine_uk_latest_version,
-)
-from .policyengine_uk import (
-    policyengine_uk_model as policyengine_uk_model,
-)
-from .policyengine_us import (
-    policyengine_us_latest_version as policyengine_us_latest_version,
-)
-from .policyengine_us import (
-    policyengine_us_model as policyengine_us_model,
-)
-from .report import Report as Report
-from .report_element import ReportElement as ReportElement
-from .simulation import Simulation as Simulation
-from .user import User as User
-from .versioned_dataset import VersionedDataset as VersionedDataset
-
-# Rebuild models to handle circular references
-from .aggregate import Aggregate
-from .aggregate_change import AggregateChange
-from .simulation import Simulation
-Aggregate.model_rebuild()
-AggregateChange.model_rebuild()
-Simulation.model_rebuild()
diff --git a/build/lib/policyengine/models/aggregate.py b/build/lib/policyengine/models/aggregate.py
deleted file mode 100644
index 031cad87..00000000
--- a/build/lib/policyengine/models/aggregate.py
+++ /dev/null
@@ -1,132 +0,0 @@
-from enum import Enum
-from typing import TYPE_CHECKING, Literal
-from uuid import uuid4
-
-import pandas as pd
-from microdf import MicroDataFrame
-from pydantic import BaseModel, Field
-
-if TYPE_CHECKING:
-    from policyengine.models import Simulation
-
-
-class AggregateType(str, Enum):
-    SUM = "sum"
-    MEAN = "mean"
-    MEDIAN = "median"
-    COUNT = "count"
-
-
-class Aggregate(BaseModel):
-    id: str = Field(default_factory=lambda: str(uuid4()))
-    simulation: "Simulation | None" = None
-    entity: str
-    variable_name: str
-    year: int | None = None
-    filter_variable_name: str | None = None
-    filter_variable_value: str | None = None
-    filter_variable_leq: float | None = None
-    filter_variable_geq: float | None = None
-    aggregate_function: Literal[
-        AggregateType.SUM, AggregateType.MEAN, AggregateType.MEDIAN, AggregateType.COUNT
-    ]
-    reportelement_id: str | None = None
-
-    value: float | None = None
-
-    @staticmethod
-    def run(aggregates: list["Aggregate"]) -> list["Aggregate"]:
-        """Process aggregates, handling multiple simulations if necessary."""
-        # Group aggregates by simulation
-        simulation_groups = {}
-        for agg in aggregates:
-            sim_id = id(agg.simulation) if agg.simulation else None
-            if sim_id not in simulation_groups:
-                simulation_groups[sim_id] = []
-            simulation_groups[sim_id].append(agg)
-
-        # Process each simulation group separately
-        all_results = []
-        for sim_id, sim_aggregates in simulation_groups.items():
-            if not sim_aggregates:
-                continue
-
-            # Get the simulation from the first aggregate in this group
-            simulation = sim_aggregates[0].simulation
-            if simulation is None:
-                raise ValueError("Aggregate has no simulation attached")
-
-            # Process this simulation's aggregates
-            group_results = Aggregate._process_simulation_aggregates(
-                sim_aggregates, simulation
-            )
-            all_results.extend(group_results)
-
-        return all_results
-
-    @staticmethod
-    def _process_simulation_aggregates(
-        aggregates: list["Aggregate"], simulation: "Simulation"
-    ) -> list["Aggregate"]:
-        """Process aggregates for a single simulation."""
-        results = []
-
-        tables = simulation.result
-        # Copy tables to ensure we don't modify original dataframes
-        tables = {k: v.copy() for k, v in tables.items()}
-        for table in tables:
-            tables[table] = pd.DataFrame(tables[table])
-            weight_col = f"{table}_weight"
-            if weight_col in tables[table].columns:
-                tables[table] = MicroDataFrame(
-                    tables[table], weights=weight_col
-                )
-
-        for agg in aggregates:
-            if agg.entity not in tables:
-                raise ValueError(
-                    f"Entity {agg.entity} not found in simulation results"
-                )
-            table = tables[agg.entity]
-
-            if agg.variable_name not in table.columns:
-                raise ValueError(
-                    f"Variable {agg.variable_name} not found in entity {agg.entity}"
-                )
-
-            df = table
-
-            if agg.year is None:
-                agg.year = simulation.dataset.year
-
-            if agg.filter_variable_name is not None:
-                if agg.filter_variable_name not in df.columns:
-                    raise ValueError(
-                        f"Filter variable {agg.filter_variable_name} not found in entity {agg.entity}"
-                    )
-                if agg.filter_variable_value is not None:
-                    df = df[
-                        df[agg.filter_variable_name]
-                        == agg.filter_variable_value
-                    ]
-                if agg.filter_variable_leq is not None:
-                    df = df[
-                        df[agg.filter_variable_name] <= agg.filter_variable_leq
-                    ]
-                if agg.filter_variable_geq is not None:
-                    df = df[
-                        df[agg.filter_variable_name] >= agg.filter_variable_geq
-                    ]
-
-            if agg.aggregate_function == AggregateType.SUM:
-                agg.value = float(df[agg.variable_name].sum())
-            elif agg.aggregate_function == AggregateType.MEAN:
-                agg.value = float(df[agg.variable_name].mean())
-            elif agg.aggregate_function == AggregateType.MEDIAN:
-                agg.value = float(df[agg.variable_name].median())
-            elif agg.aggregate_function == AggregateType.COUNT:
-                agg.value = float((df[agg.variable_name] > 0).sum())
-
-            results.append(agg)
-
-        return results
diff --git a/build/lib/policyengine/models/aggregate_change.py b/build/lib/policyengine/models/aggregate_change.py
deleted file mode 100644
index e0a400df..00000000
--- a/build/lib/policyengine/models/aggregate_change.py
+++ /dev/null
@@ -1,143 +0,0 @@
-from enum import Enum
-from typing import TYPE_CHECKING, Literal
-from uuid import uuid4
-
-import pandas as pd
-from microdf import MicroDataFrame
-from pydantic import BaseModel, Field
-
-if TYPE_CHECKING:
-    from policyengine.models import Simulation
-
-
-class AggregateType(str, Enum):
-    SUM = "sum"
-    MEAN = "mean"
-    MEDIAN = "median"
-    COUNT = "count"
-
-
-class AggregateChange(BaseModel):
-    id: str = Field(default_factory=lambda: str(uuid4()))
-    baseline_simulation: "Simulation | None" = None
-    comparison_simulation: "Simulation | None" = None
-    entity: str
-    variable_name: str
-    year: int | None = None
-    filter_variable_name: str | None = None
-    filter_variable_value: str | None = None
-    filter_variable_leq: float | None = None
-    filter_variable_geq: float | None = None
-    aggregate_function: Literal[
-        AggregateType.SUM, AggregateType.MEAN, AggregateType.MEDIAN, AggregateType.COUNT
-    ]
-    reportelement_id: str | None = None
-
-    baseline_value: float | None = None
-    comparison_value: float | None = None
-    change: float | None = None
-    relative_change: float | None = None
-
-    @staticmethod
-    def run(aggregate_changes: list["AggregateChange"]) -> list["AggregateChange"]:
-        """Process aggregate changes, handling multiple simulation pairs."""
-        results = []
-
-        for agg_change in aggregate_changes:
-            if agg_change.baseline_simulation is None:
-                raise ValueError("AggregateChange has no baseline simulation attached")
-            if agg_change.comparison_simulation is None:
-                raise ValueError("AggregateChange has no comparison simulation attached")
-
-            # Compute baseline value
-            baseline_value = AggregateChange._compute_single_aggregate(
-                agg_change, agg_change.baseline_simulation
-            )
-
-            # Compute comparison value
-            comparison_value = AggregateChange._compute_single_aggregate(
-                agg_change, agg_change.comparison_simulation
-            )
-
-            # Compute changes
-            agg_change.baseline_value = baseline_value
-            agg_change.comparison_value = comparison_value
-            agg_change.change = comparison_value - baseline_value
-
-            # Compute relative change (avoiding division by zero)
-            if baseline_value != 0:
-                agg_change.relative_change = (comparison_value - baseline_value) / abs(baseline_value)
-            else:
-                agg_change.relative_change = None if comparison_value == 0 else float('inf')
-
-            results.append(agg_change)
-
-        return results
-
-    @staticmethod
-    def _compute_single_aggregate(
-        agg_change: "AggregateChange", simulation: "Simulation"
-    ) -> float:
-        """Compute aggregate value for a single simulation."""
-        tables = simulation.result
-        # Copy tables to ensure we don't modify original dataframes
-        tables = {k: v.copy() for k, v in tables.items()}
-
-        for table in tables:
-            tables[table] = pd.DataFrame(tables[table])
-            weight_col = f"{table}_weight"
-            if weight_col in tables[table].columns:
-                tables[table] = MicroDataFrame(
-                    tables[table], weights=weight_col
-                )
-
-        if agg_change.entity not in tables:
-            raise ValueError(
-                f"Entity {agg_change.entity} not found in simulation results"
-            )
-
-        table = tables[agg_change.entity]
-
-        if agg_change.variable_name not in table.columns:
-            raise ValueError(
-                f"Variable {agg_change.variable_name} not found in entity {agg_change.entity}"
-            )
-
-        df = table
-
-        if agg_change.year is None:
-            agg_change.year = simulation.dataset.year
-
-        # Apply filters
-        if agg_change.filter_variable_name is not None:
-            if agg_change.filter_variable_name not in df.columns:
-                raise ValueError(
-                    f"Filter variable {agg_change.filter_variable_name} not found in entity {agg_change.entity}"
-                )
-            if agg_change.filter_variable_value is not None:
-                df = df[
-                    df[agg_change.filter_variable_name]
-                    == agg_change.filter_variable_value
-                ]
-            if agg_change.filter_variable_leq is not None:
-                df = df[
-                    df[agg_change.filter_variable_name] <= agg_change.filter_variable_leq
-                ]
-            if agg_change.filter_variable_geq is not None:
-                df = df[
-                    df[agg_change.filter_variable_name] >= agg_change.filter_variable_geq
-                ]
-
-        # Compute aggregate
-        if agg_change.aggregate_function == AggregateType.SUM:
-            value = float(df[agg_change.variable_name].sum())
-        elif agg_change.aggregate_function == AggregateType.MEAN:
-            value = float(df[agg_change.variable_name].mean())
-        elif agg_change.aggregate_function == AggregateType.MEDIAN:
-            value = float(df[agg_change.variable_name].median())
-        elif agg_change.aggregate_function == AggregateType.COUNT:
-            value = float((df[agg_change.variable_name] > 0).sum())
-        else:
-            raise ValueError(f"Unknown aggregate function: {agg_change.aggregate_function}")
-
-        return value
\ No newline at end of file
diff --git a/build/lib/policyengine/models/baseline_parameter_value.py b/build/lib/policyengine/models/baseline_parameter_value.py
deleted file mode 100644
index 8afb6e22..00000000
--- a/build/lib/policyengine/models/baseline_parameter_value.py
+++ /dev/null
@@ -1,16 +0,0 @@
-from datetime import datetime
-from uuid import uuid4
-
-from pydantic import BaseModel, Field
-
-from .model_version import ModelVersion
-from .parameter import Parameter
-
-
-class BaselineParameterValue(BaseModel):
-    id: str = Field(default_factory=lambda: str(uuid4()))
-    parameter: Parameter
-    model_version: ModelVersion
-    value: float | int | str | bool | list | None = None
-    start_date: datetime
-    end_date: datetime | None = None
diff --git a/build/lib/policyengine/models/baseline_variable.py b/build/lib/policyengine/models/baseline_variable.py
deleted file mode 100644
index b0e739b1..00000000
--- a/build/lib/policyengine/models/baseline_variable.py
+++ /dev/null
@@ -1,12 +0,0 @@
-from pydantic import BaseModel
-
-from .model_version import ModelVersion
-
-
-class BaselineVariable(BaseModel):
-    id: str
-    model_version: ModelVersion
-    entity: str
-    label: str | None = None
-    description: str | None = None
-    data_type: type | None = None
diff --git a/build/lib/policyengine/models/dataset.py b/build/lib/policyengine/models/dataset.py
deleted file mode 100644
index 59dd626f..00000000
--- a/build/lib/policyengine/models/dataset.py
+++ /dev/null
@@ -1,18 +0,0 @@
-from typing import Any
-from uuid import uuid4
-
-from pydantic import BaseModel, Field
-
-from .model import Model
-from .versioned_dataset import VersionedDataset
-
-
-class Dataset(BaseModel):
-    id: str = Field(default_factory=lambda: str(uuid4()))
-    name: str
-    description: str | None = None
-    version: str | None = None
-    versioned_dataset: VersionedDataset | None = None
-    year: int | None = None
-    data: Any | None = None
-    model: Model | None = None
diff --git a/build/lib/policyengine/models/dynamic.py b/build/lib/policyengine/models/dynamic.py
deleted file mode 100644
index 40cf364f..00000000
--- a/build/lib/policyengine/models/dynamic.py
+++ /dev/null
@@ -1,15 +0,0 @@
-from collections.abc import Callable
-from datetime import datetime
-from uuid import uuid4
-
-from pydantic import BaseModel, Field
-
-
-class Dynamic(BaseModel):
-    id: str = Field(default_factory=lambda: str(uuid4()))
-    name: str
-    description: str | None = None
-    parameter_values: list[str] = []
-    simulation_modifier: Callable | None = None
-    created_at: datetime = Field(default_factory=datetime.now)
-    updated_at: datetime = Field(default_factory=datetime.now)
diff --git a/build/lib/policyengine/models/model.py b/build/lib/policyengine/models/model.py
deleted file mode 100644
index 89cac9b8..00000000
--- a/build/lib/policyengine/models/model.py
+++ /dev/null
@@ -1,126 +0,0 @@
-from collections.abc import Callable
-from datetime import datetime
-from typing import TYPE_CHECKING
-
-from pydantic import BaseModel
-
-if TYPE_CHECKING:
-    from .baseline_parameter_value import BaselineParameterValue
-    from .baseline_variable import BaselineVariable
-    from .parameter import Parameter
-
-
-class Model(BaseModel):
-    id: str
-    name: str
-    description: str | None = None
-    simulation_function: Callable
-
-    def create_seed_objects(self, model_version):
-        from policyengine_core.parameters import Parameter as CoreParameter
-
-        from .baseline_parameter_value import BaselineParameterValue
-        from .baseline_variable import BaselineVariable
-        from .parameter import Parameter
-
-        if self.id == "policyengine_uk":
-            from policyengine_uk.tax_benefit_system import system
-        elif self.id == "policyengine_us":
-            from policyengine_us.system import system
-        else:
-            raise ValueError("Unsupported model.")
-
-        parameters = []
-        baseline_parameter_values = []
-        baseline_variables = []
-        seen_parameter_ids = set()
-
-        for parameter in system.parameters.get_descendants():
-            # Skip if we've already processed this parameter ID
-            if parameter.name in seen_parameter_ids:
-                continue
-            seen_parameter_ids.add(parameter.name)
-            param = Parameter(
-                id=parameter.name,
-                description=parameter.description,
-                data_type=None,
-                model=self,
-                label=parameter.metadata.get("label"),
-                unit=parameter.metadata.get("unit"),
-            )
-            parameters.append(param)
-            if isinstance(parameter, CoreParameter):
-                values = parameter.values_list[::-1]
-                param.data_type = type(values[-1].value)
-                for i in range(len(values)):
-                    value_at_instant = values[i]
-                    instant_str = safe_parse_instant_str(
-                        value_at_instant.instant_str
-                    )
-                    if i + 1 < len(values):
-                        next_instant_str = safe_parse_instant_str(
-                            values[i + 1].instant_str
-                        )
-                    else:
-                        next_instant_str = None
-                    baseline_param_value = BaselineParameterValue(
-                        parameter=param,
-                        model_version=model_version,
-                        value=value_at_instant.value,
-                        start_date=instant_str,
-                        end_date=next_instant_str,
-                    )
-                    baseline_parameter_values.append(baseline_param_value)
-
-        for variable in system.variables.values():
-            baseline_variable = BaselineVariable(
-                id=variable.name,
-                model_version=model_version,
-                entity=variable.entity.key,
-                label=variable.label,
-                description=variable.documentation,
-                data_type=variable.value_type,
-            )
-            baseline_variables.append(baseline_variable)
-
-        return SeedObjects(
-            parameters=parameters,
-            baseline_parameter_values=baseline_parameter_values,
-            baseline_variables=baseline_variables,
-        )
-
-
-def safe_parse_instant_str(instant_str: str) -> datetime:
-    if instant_str == "0000-01-01":
-        return datetime(1, 1, 1)
-    else:
-        try:
-            return datetime.strptime(instant_str, "%Y-%m-%d")
-        except ValueError:
-            # Handle invalid dates like 2021-06-31
-            # Try to parse year and month, then use last valid day
-            parts = instant_str.split("-")
-            if len(parts) == 3:
-                year = int(parts[0])
-                month = int(parts[1])
-                day = int(parts[2])
-
-                # Find the last valid day of the month
-                import calendar
-
-                last_day = calendar.monthrange(year, month)[1]
-                if day > last_day:
-                    print(
-                        f"Warning: Invalid date {instant_str}, using {year}-{month:02d}-{last_day:02d}"
-                    )
-                    return datetime(year, month, last_day)
-
-            # If we can't parse it at all, print and raise
-            print(f"Error: Cannot parse date {instant_str}")
-            raise
-
-
-class SeedObjects(BaseModel):
-    parameters: list["Parameter"]
-    baseline_parameter_values: list["BaselineParameterValue"]
-    baseline_variables: list["BaselineVariable"]
diff --git a/build/lib/policyengine/models/parameter.py b/build/lib/policyengine/models/parameter.py
deleted file mode 100644
index ec7ef7be..00000000
--- a/build/lib/policyengine/models/parameter.py
+++ /dev/null
@@ -1,14 +0,0 @@
-from uuid import uuid4
-
-from pydantic import BaseModel, Field
-
-from .model import Model
-
-
-class Parameter(BaseModel):
-    id: str = Field(default_factory=lambda: str(uuid4()))
-    description: str | None = None
-    data_type: type | None = None
-    model: Model | None = None
-    label: str | None = None
-    unit: str | None = None
diff --git a/build/lib/policyengine/models/parameter_value.py b/build/lib/policyengine/models/parameter_value.py
deleted file mode 100644
index a7867557..00000000
--- a/build/lib/policyengine/models/parameter_value.py
+++ /dev/null
@@ -1,14 +0,0 @@
-from datetime import datetime
-from uuid import uuid4
-
-from pydantic import BaseModel, Field
-
-from .parameter import Parameter
-
-
-class ParameterValue(BaseModel):
-    id: str = Field(default_factory=lambda: str(uuid4()))
-    parameter: Parameter
-    value: float | int | str | bool | list | None = None
-    start_date: datetime
-    end_date: datetime | None = None
diff --git a/build/lib/policyengine/models/policy.py b/build/lib/policyengine/models/policy.py
deleted file mode 100644
index 20587d85..00000000
--- a/build/lib/policyengine/models/policy.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from collections.abc import Callable
-from datetime import datetime
-from uuid import uuid4
-
-from pydantic import BaseModel, Field
-
-from .parameter_value import ParameterValue
-
-
-class Policy(BaseModel):
-    id: str = Field(default_factory=lambda: str(uuid4()))
-    name: str
-    description: str | None = None
-    parameter_values: list[ParameterValue] = []
-    simulation_modifier: Callable | None = None
-    created_at: datetime = Field(default_factory=datetime.now)
-    updated_at: datetime = Field(default_factory=datetime.now)
diff --git a/build/lib/policyengine/models/policyengine_uk.py b/build/lib/policyengine/models/policyengine_uk.py
deleted file mode 100644
index 5b97ccfb..00000000
--- a/build/lib/policyengine/models/policyengine_uk.py
+++ /dev/null
@@ -1,113 +0,0 @@
-import importlib.metadata
-
-import pandas as pd
-
-from ..models import Dataset, Dynamic, Model, ModelVersion, Policy
-
-
-def run_policyengine_uk(
-    dataset: "Dataset",
-    policy: "Policy | None" = None,
-    dynamic: "Dynamic | None" = None,
-) -> dict[str, "pd.DataFrame"]:
-    data: dict[str, pd.DataFrame] = dataset.data
-
-    from policyengine_uk import Microsimulation
-    from policyengine_uk.data import UKSingleYearDataset
-
-    pe_input_data = UKSingleYearDataset(
-        person=data["person"],
-        benunit=data["benunit"],
-        household=data["household"],
-        fiscal_year=dataset.year,
-    )
-
-    sim = Microsimulation(dataset=pe_input_data)
-    sim.default_calculation_period = dataset.year
-
-    def simulation_modifier(sim: Microsimulation):
-        if policy is not None and len(policy.parameter_values) > 0:
-            for parameter_value in policy.parameter_values:
-                sim.tax_benefit_system.parameters.get_child(
-                    parameter_value.parameter.id
-                ).update(
-                    value=parameter_value.value,
-                    start=parameter_value.start_date.strftime("%Y-%m-%d"),
-                    stop=parameter_value.end_date.strftime("%Y-%m-%d")
-                    if parameter_value.end_date
-                    else None,
-                )
-
-        if dynamic is not None and len(dynamic.parameter_values) > 0:
-            for parameter_value in dynamic.parameter_values:
-                sim.tax_benefit_system.parameters.get_child(
-                    parameter_value.parameter.id
-                ).update(
-                    value=parameter_value.value,
-                    start=parameter_value.start_date.strftime("%Y-%m-%d"),
-                    stop=parameter_value.end_date.strftime("%Y-%m-%d")
-                    if parameter_value.end_date
-                    else None,
-                )
-
-        if dynamic is not None and dynamic.simulation_modifier is not None:
-            dynamic.simulation_modifier(sim)
-        if policy is not None and policy.simulation_modifier is not None:
-            policy.simulation_modifier(sim)
-
-    simulation_modifier(sim)
-
-    output_data = {}
-
-    variable_blacklist = [  # TEMPORARY: we need to fix policyengine-uk to make these only take a long time with non-default parameters set to true.
-        "is_uc_entitled_baseline",
-        "income_elasticity_lsr",
-        "child_benefit_opts_out",
-        "housing_benefit_baseline_entitlement",
-        "baseline_ctc_entitlement",
-        "pre_budget_change_household_tax",
-        "pre_budget_change_household_net_income",
-        "is_on_cliff",
-        "marginal_tax_rate_on_capital_gains",
-        "relative_capital_gains_mtr_change",
-        "pre_budget_change_ons_equivalised_income_decile",
-        "substitution_elasticity",
-        "marginal_tax_rate",
-        "cliff_evaluated",
-        "cliff_gap",
-        "substitution_elasticity_lsr",
-        "relative_wage_change",
-        "relative_income_change",
-        "pre_budget_change_household_benefits",
-    ]
-
-    for entity in ["person", "benunit", "household"]:
-        output_data[entity] = pd.DataFrame()
-        for variable in sim.tax_benefit_system.variables.values():
-            correct_entity = variable.entity.key == entity
-            if variable.name in variable_blacklist:
-                continue
-            if variable.definition_period != "year":
-                continue
-            if correct_entity:
-                output_data[entity][variable.name] = sim.calculate(
-                    variable.name
-                ).values
-        output_data[entity] = pd.DataFrame(output_data[entity])
-
-    return output_data
-
-
-policyengine_uk_model = Model(
-    id="policyengine_uk",
-    name="PolicyEngine UK",
-    description="PolicyEngine's open-source tax-benefit microsimulation model.",
-    simulation_function=run_policyengine_uk,
-)
-
-# Get policyengine-uk version
-
-policyengine_uk_latest_version = ModelVersion(
-    model=policyengine_uk_model,
-    version=importlib.metadata.distribution("policyengine_uk").version,
-)
diff --git a/build/lib/policyengine/models/policyengine_us.py b/build/lib/policyengine/models/policyengine_us.py
deleted file mode 100644
index 9e2eeb7d..00000000
--- a/build/lib/policyengine/models/policyengine_us.py
+++ /dev/null
@@ -1,115 +0,0 @@
-import importlib.metadata
-
-import pandas as pd
-
-from ..models import Dataset, Dynamic, Model, ModelVersion, Policy
-
-
-def run_policyengine_us(
-    dataset: "Dataset",
-    policy: "Policy | None" = None,
-    dynamic: "Dynamic | None" = None,
-) -> dict[str, "pd.DataFrame"]:
-    data: dict[str, pd.DataFrame] = dataset.data
-
-    person_df = pd.DataFrame()
-
-    for table_name, table in data.items():
-        if table_name == "person":
-            for col in table.columns:
-                person_df[f"{col}__{dataset.year}"] = table[col].values
-        else:
-            foreign_key = data["person"][f"person_{table_name}_id"]
-            primary_key = data[table_name][f"{table_name}_id"]
-
-            projected = table.set_index(primary_key).loc[foreign_key]
-
-            for col in projected.columns:
-                person_df[f"{col}__{dataset.year}"] = projected[col].values
-
-    from policyengine_us import Microsimulation
-
-    sim = Microsimulation(dataset=person_df)
-    sim.default_calculation_period = dataset.year
-
-    def simulation_modifier(sim: Microsimulation):
-        if policy is not None and len(policy.parameter_values) > 0:
-            for parameter_value in policy.parameter_values:
-                sim.tax_benefit_system.parameters.get_child(
-                    parameter_value.parameter.id
-                ).update(
-                    parameter_value.value,
-                    start=parameter_value.start_date.strftime("%Y-%m-%d"),
-                    stop=parameter_value.end_date.strftime("%Y-%m-%d")
-                    if parameter_value.end_date
-                    else None,
-                )
-
-        if dynamic is not None and len(dynamic.parameter_values) > 0:
-            for parameter_value in dynamic.parameter_values:
-                sim.tax_benefit_system.parameters.get_child(
-                    parameter_value.parameter.id
-                ).update(
-                    parameter_value.value,
-                    start=parameter_value.start_date.strftime("%Y-%m-%d"),
-                    stop=parameter_value.end_date.strftime("%Y-%m-%d")
-                    if parameter_value.end_date
-                    else None,
-                )
-
-        if dynamic is not None and dynamic.simulation_modifier is not None:
-            dynamic.simulation_modifier(sim)
-        if policy is not None and policy.simulation_modifier is not None:
-            policy.simulation_modifier(sim)
-
-    simulation_modifier(sim)
-
-    # Skip reforms for now
-
-    output_data = {}
-
-    variable_whitelist = [
-        "household_net_income",
-    ]
-
-    for variable in variable_whitelist:
-        sim.calculate(variable)
-
-    for entity in [
-        "person",
-        "marital_unit",
-        "family",
-        "tax_unit",
-        "spm_unit",
-        "household",
-    ]:
-        output_data[entity] = pd.DataFrame()
-        for variable in sim.tax_benefit_system.variables.values():
-            correct_entity = variable.entity.key == entity
-            if str(dataset.year) not in list(
-                map(str, sim.get_holder(variable.name).get_known_periods())
-            ):
-                continue
-            if variable.definition_period != "year":
-                continue
-            if not correct_entity:
-                continue
-            output_data[entity][variable.name] = sim.calculate(variable.name).values
-
-    return output_data
-
-
-policyengine_us_model = Model(
-    id="policyengine_us",
-    name="PolicyEngine US",
-    description="PolicyEngine's open-source tax-benefit microsimulation model.",
-    simulation_function=run_policyengine_us,
-)
-
-# Get policyengine-us version
-
-
-policyengine_us_latest_version = ModelVersion(
-    model=policyengine_us_model,
-    version=importlib.metadata.distribution("policyengine_us").version,
-)
diff --git a/build/lib/policyengine/models/report.py b/build/lib/policyengine/models/report.py
deleted file mode 100644
index 2ae0cd3b..00000000
--- a/build/lib/policyengine/models/report.py
+++ /dev/null
@@ -1,20 +0,0 @@
-import uuid
-from datetime import datetime
-from typing import TYPE_CHECKING, ForwardRef
-
-from pydantic import BaseModel, Field
-
-if TYPE_CHECKING:
-    from .report_element import ReportElement
-
-
-class Report(BaseModel):
-    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
-    label: str
-    created_at: datetime | None = None
-    elements: list[ForwardRef("ReportElement")] = Field(default_factory=list)
-
-
-# Import after class definition to avoid circular import
-from .report_element import ReportElement
-Report.model_rebuild()
diff --git a/build/lib/policyengine/models/report_element.py b/build/lib/policyengine/models/report_element.py
deleted file mode 100644
index 180ec26c..00000000
--- a/build/lib/policyengine/models/report_element.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import uuid
-from datetime import datetime
-from typing import Literal
-
-from pydantic import BaseModel, Field
-
-
-class ReportElement(BaseModel):
-    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
-    label: str
-    type: Literal["chart", "markdown"]
-
-    # Data source
-    data_table: Literal["aggregates", "aggregate_changes"] | None = None  # Which table to pull from
-
-    # Chart configuration
-    chart_type: (
-        Literal["bar", "line", "scatter", "area", "pie", "histogram"] | None
-    ) = None
-    x_axis_variable: str | None = None  # Column name from the table
-    y_axis_variable: str | None = None  # Column name from the table
-    group_by: str | None = None  # Column to group/split series by
-    color_by: str | None = None  # Column for color mapping
-    size_by: str | None = None  # Column for size mapping (bubble charts)
-
-    # Markdown specific
-    markdown_content: str | None = None
-
-    # Metadata
-    report_id: str | None = None
-    user_id: str | None = None
-    model_version_id: str | None = None
-    position: int | None = None
-    visible: bool | None = True
-    custom_config: dict | None = None  # Additional chart-specific config
-    report_element_metadata: dict | None = None  # General metadata field for flexible data storage
-    created_at: datetime | None = None
-    updated_at: datetime | None = None
diff --git a/build/lib/policyengine/models/simulation.py b/build/lib/policyengine/models/simulation.py
deleted file mode 100644
index a6ed7a5a..00000000
--- a/build/lib/policyengine/models/simulation.py
+++ /dev/null
@@ -1,35 +0,0 @@
-from datetime import datetime
-from typing import Any
-from uuid import uuid4
-
-from pydantic import BaseModel, Field
-
-from .dataset import Dataset
-from .dynamic import Dynamic
-from .model import Model
-from .model_version import ModelVersion
-from .policy import Policy
-
-
-class Simulation(BaseModel):
-    id: str = Field(default_factory=lambda: str(uuid4()))
-    created_at: datetime = Field(default_factory=datetime.now)
-    updated_at: datetime = Field(default_factory=datetime.now)
-
-    policy: Policy | None = None
-    dynamic: Dynamic | None = None
-    dataset: Dataset
-
-    model: Model
-    model_version: ModelVersion
-    result: Any | None = None
-    aggregates: list = Field(default_factory=list)  # Will be list[Aggregate] but avoid circular import
-
-    def run(self):
-        self.result = self.model.simulation_function(
-            dataset=self.dataset,
-            policy=self.policy,
-            dynamic=self.dynamic,
-        )
-        self.updated_at = datetime.now()
-        return self.result
diff --git a/build/lib/policyengine/models/user.py b/build/lib/policyengine/models/user.py
deleted file mode 100644
index dee924e1..00000000
--- a/build/lib/policyengine/models/user.py
+++ /dev/null
@@ -1,14 +0,0 @@
-import uuid
-from datetime import datetime
-
-from pydantic import BaseModel, Field
-
-
-class User(BaseModel):
-    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
-    username: str
-    first_name: str | None = None
-    last_name: str | None = None
-    email: str | None = None
-    created_at: datetime | None = None
-    updated_at: datetime | None = None
diff --git a/build/lib/policyengine/models/versioned_dataset.py b/build/lib/policyengine/models/versioned_dataset.py
b/build/lib/policyengine/models/versioned_dataset.py deleted file mode 100644 index 2f5e14f7..00000000 --- a/build/lib/policyengine/models/versioned_dataset.py +++ /dev/null @@ -1,12 +0,0 @@ -from uuid import uuid4 - -from pydantic import BaseModel, Field - -from .model import Model - - -class VersionedDataset(BaseModel): - id: str = Field(default_factory=lambda: str(uuid4())) - name: str - description: str - model: Model | None = None diff --git a/build/lib/policyengine/utils/charts.py b/build/lib/policyengine/utils/charts.py deleted file mode 100644 index 0cee7048..00000000 --- a/build/lib/policyengine/utils/charts.py +++ /dev/null @@ -1,286 +0,0 @@ -"""Chart formatting utilities for PolicyEngine.""" - -import plotly.graph_objects as go -from IPython.display import HTML - -COLOUR_SCHEMES = { - "teal": { - "primary": "#319795", - "secondary": "#38B2AC", - "tertiary": "#4FD1C5", - "light": "#81E6D9", - "lighter": "#B2F5EA", - "lightest": "#E6FFFA", - "dark": "#2C7A7B", - "darker": "#285E61", - "darkest": "#234E52", - }, - "blue": { - "primary": "#0EA5E9", - "secondary": "#0284C7", - "tertiary": "#38BDF8", - "light": "#7DD3FC", - "lighter": "#BAE6FD", - "lightest": "#E0F2FE", - "dark": "#026AA2", - "darker": "#075985", - "darkest": "#0C4A6E", - }, - "gray": { - "primary": "#6B7280", - "secondary": "#9CA3AF", - "tertiary": "#D1D5DB", - "light": "#E2E8F0", - "lighter": "#F2F4F7", - "lightest": "#F9FAFB", - "dark": "#4B5563", - "darker": "#344054", - "darkest": "#101828", - }, -} - -DEFAULT_COLOURS = [ - COLOUR_SCHEMES["teal"]["primary"], - COLOUR_SCHEMES["blue"]["primary"], - COLOUR_SCHEMES["teal"]["secondary"], - COLOUR_SCHEMES["blue"]["secondary"], - COLOUR_SCHEMES["teal"]["tertiary"], - COLOUR_SCHEMES["blue"]["tertiary"], - COLOUR_SCHEMES["gray"]["dark"], - COLOUR_SCHEMES["teal"]["dark"], -] - - -def add_fonts() -> HTML: - """Return HTML to add Google Fonts for Roboto and Roboto Mono.""" - return HTML(""" - - - - """) - - -def format_figure( - fig: go.Figure, - title: str | None = None, - x_title: str | None = None, - y_title: str | None = None, - colour_scheme: str = "teal", - show_grid: bool = True, - show_legend: bool = True, - height: int | None = None, - width: int | None = None, -) -> go.Figure: - """Apply consistent formatting to a Plotly figure. 
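Every domain model deleted in this patch, `VersionedDataset` included, uses the same identifier convention: a string primary key minted per instance by `default_factory`. A minimal sketch of the pattern with a stand-alone copy of the class, so it runs without the package installed:

```python
from uuid import uuid4

from pydantic import BaseModel, Field


class VersionedDataset(BaseModel):
    # default_factory runs once per instance; a plain default=str(uuid4())
    # would bake a single UUID into the class definition.
    id: str = Field(default_factory=lambda: str(uuid4()))
    name: str
    description: str


a = VersionedDataset(name="FRS", description="Family Resources Survey builds")
b = VersionedDataset(name="CPS", description="Current Population Survey builds")
assert a.id != b.id
```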
- - Args: - fig: The Plotly figure to format - title: Optional title for the chart - x_title: Optional x-axis title - y_title: Optional y-axis title - colour_scheme: Colour scheme name (teal, blue, gray) - show_grid: Whether to show gridlines - show_legend: Whether to show the legend - height: Optional figure height in pixels - width: Optional figure width in pixels - - Returns: - The formatted figure - """ - - colours = COLOUR_SCHEMES.get(colour_scheme, COLOUR_SCHEMES["teal"]) - - # Update traces with colour scheme - for i, trace in enumerate(fig.data): - if hasattr(trace, "marker"): - trace.marker.color = DEFAULT_COLOURS[i % len(DEFAULT_COLOURS)] - if hasattr(trace, "line"): - trace.line.color = DEFAULT_COLOURS[i % len(DEFAULT_COLOURS)] - trace.line.width = 2 - - # Base layout settings - layout_updates = { - "font": { - "family": "Roboto, sans-serif", - "size": 14, - "color": COLOUR_SCHEMES["gray"]["darkest"], - }, - "plot_bgcolor": "white", - "paper_bgcolor": "white", - "showlegend": show_legend, - "hovermode": "x unified", - "hoverlabel": { - "bgcolor": "white", - "font": {"family": "Roboto Mono, monospace", "size": 12}, - "bordercolor": colours["light"], - }, - } - - # Add title if provided - if title: - layout_updates["title"] = { - "text": title, - "font": { - "family": "Roboto, sans-serif", - "size": 20, - "color": COLOUR_SCHEMES["gray"]["darkest"], - "weight": 500, - }, - } - - # Configure axes - axis_config = { - "showgrid": show_grid, - "gridcolor": COLOUR_SCHEMES["gray"]["light"], - "gridwidth": 1, - "zeroline": True, - "zerolinecolor": COLOUR_SCHEMES["gray"]["lighter"], - "zerolinewidth": 1, - "tickfont": { - "family": "Roboto Mono, monospace", - "size": 11, - "color": COLOUR_SCHEMES["gray"]["primary"], - }, - "titlefont": { - "family": "Roboto, sans-serif", - "size": 14, - "color": COLOUR_SCHEMES["gray"]["dark"], - }, - "linecolor": COLOUR_SCHEMES["gray"]["light"], - "linewidth": 1, - "showline": True, - "mirror": False, - } - - layout_updates["xaxis"] = axis_config.copy() - layout_updates["yaxis"] = axis_config.copy() - - if x_title: - layout_updates["xaxis"]["title"] = x_title - if y_title: - layout_updates["yaxis"]["title"] = y_title - - layout_updates["showlegend"] = len(fig.data) > 1 and show_legend - - # Set dimensions if provided - if height: - layout_updates["height"] = height - if width: - layout_updates["width"] = width - - fig.update_layout(**layout_updates) - - fig.update_xaxes(title_font_color=COLOUR_SCHEMES["gray"]["primary"]) - fig.update_yaxes(title_font_color=COLOUR_SCHEMES["gray"]["primary"]) - - # Add text annotations to bars in bar charts - if any(isinstance(trace, go.Bar) for trace in fig.data): - for trace in fig.data: - if isinstance(trace, go.Bar): - trace.texttemplate = "%{y:,.0f}" - trace.textposition = "outside" - trace.textfont = { - "family": "Roboto Mono, monospace", - "size": 11, - "color": COLOUR_SCHEMES["gray"]["primary"], - } - - return fig - - -def create_bar_chart( - data: dict[str, list], - x: str, - y: str, - title: str | None = None, - colour_scheme: str = "teal", - **kwargs, -) -> go.Figure: - """Create a formatted bar chart. 
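The body of `format_figure` boils down to round-robin colour assignment over `fig.data` plus a single `update_layout` call. A minimal sketch of the cycling pattern, with a three-colour stand-in for `DEFAULT_COLOURS` (assumes `plotly` is installed):

```python
import plotly.graph_objects as go

PALETTE = ["#319795", "#0EA5E9", "#38B2AC"]  # stand-in for DEFAULT_COLOURS

fig = go.Figure(
    data=[
        go.Bar(x=["a", "b"], y=[1, 2]),
        go.Bar(x=["a", "b"], y=[2, 1]),
    ]
)
for i, trace in enumerate(fig.data):
    # i % len(PALETTE) wraps, so any number of traces gets a colour.
    trace.marker.color = PALETTE[i % len(PALETTE)]

fig.update_layout(
    plot_bgcolor="white",
    paper_bgcolor="white",
    showlegend=len(fig.data) > 1,  # same single-trace rule as format_figure
)
```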
- - Args: - data: Dictionary with data for the chart - x: Column name for x-axis - y: Column name for y-axis - title: Optional chart title - colour_scheme: Colour scheme to use - **kwargs: Additional arguments for format_figure - - Returns: - Formatted bar chart figure - """ - fig = go.Figure( - data=[ - go.Bar( - x=data[x], - y=data[y], - marker_color=COLOUR_SCHEMES[colour_scheme]["primary"], - marker_line_color=COLOUR_SCHEMES[colour_scheme]["dark"], - marker_line_width=1, - hovertemplate=f"{x}: " - + "%{x}<br>
" - + f"{y}: " - + "%{y:,.0f}", - ) - ] - ) - - return format_figure( - fig, - title=title, - x_title=x, - y_title=y, - colour_scheme=colour_scheme, - **kwargs, - ) - - -def create_line_chart( - data: dict[str, list], - x: str, - y: str | list[str], - title: str | None = None, - colour_scheme: str = "teal", - **kwargs, -) -> go.Figure: - """Create a formatted line chart. - - Args: - data: Dictionary with data for the chart - x: Column name for x-axis - y: Column name(s) for y-axis (can be a list for multiple lines) - title: Optional chart title - colour_scheme: Colour scheme to use - **kwargs: Additional arguments for format_figure - - Returns: - Formatted line chart figure - """ - traces = [] - y_columns = y if isinstance(y, list) else [y] - - for i, y_col in enumerate(y_columns): - traces.append( - go.Scatter( - x=data[x], - y=data[y_col], - mode="lines+markers", - name=y_col, - line=dict( - color=DEFAULT_COLOURS[i % len(DEFAULT_COLOURS)], width=2 - ), - marker=dict(size=6), - hovertemplate=f"{y_col}: " + "%{y:,.0f}", - ) - ) - - fig = go.Figure(data=traces) - - return format_figure( - fig, - title=title, - x_title=x, - y_title=y_columns[0] if len(y_columns) == 1 else None, - colour_scheme=colour_scheme, - **kwargs, - ) diff --git a/build/lib/policyengine/utils/compress.py b/build/lib/policyengine/utils/compress.py deleted file mode 100644 index 19180e2a..00000000 --- a/build/lib/policyengine/utils/compress.py +++ /dev/null @@ -1,20 +0,0 @@ -import pickle -from typing import Any - -import blosc - - -def compress_data(data: Any) -> bytes: - """Compress data using blosc after pickling.""" - pickled_data = pickle.dumps(data) - compressed_data = blosc.compress( - pickled_data, typesize=8, cname="zstd", clevel=9, shuffle=blosc.SHUFFLE - ) - return compressed_data - - -def decompress_data(compressed_data: bytes) -> Any: - """Decompress data using blosc and then unpickle.""" - decompressed_data = blosc.decompress(compressed_data) - data = pickle.loads(decompressed_data) - return data diff --git a/build/lib/policyengine/utils/datasets.py b/build/lib/policyengine/utils/datasets.py deleted file mode 100644 index 02090e11..00000000 --- a/build/lib/policyengine/utils/datasets.py +++ /dev/null @@ -1,71 +0,0 @@ -import pandas as pd - -from policyengine.models import Dataset - - -def create_uk_dataset( - dataset: str = "enhanced_frs_2023_24.h5", - year: int = 2029, -): - from policyengine_uk import Microsimulation - - from policyengine.models.policyengine_uk import policyengine_uk_model - - sim = Microsimulation( - dataset="hf://policyengine/policyengine-uk-data/" + dataset - ) - sim.default_calculation_period = year - - tables = { - "person": pd.DataFrame(sim.dataset[year].person), - "benunit": pd.DataFrame(sim.dataset[year].benunit), - "household": pd.DataFrame(sim.dataset[year].household), - } - - return Dataset( - id="uk", - name="UK", - description="A representative dataset for the UK, based on the Family Resources Survey.", - year=year, - model=policyengine_uk_model, - data=tables, - ) - - -def create_us_dataset( - dataset: str = "enhanced_cps_2024.h5", - year: int = 2024, -): - from policyengine_us import Microsimulation - - from policyengine.models.policyengine_us import policyengine_us_model - - sim = Microsimulation( - dataset="hf://policyengine/policyengine-us-data/" + dataset - ) - sim.default_calculation_period = year - - known_variables = sim.input_variables - - tables = { - "person": pd.DataFrame(), - "marital_unit": pd.DataFrame(), - "tax_unit": pd.DataFrame(), - "spm_unit": 
pd.DataFrame(), - "family": pd.DataFrame(), - "household": pd.DataFrame(), - } - - for variable in known_variables: - entity = sim.tax_benefit_system.variables[variable].entity.key - if variable in sim.tax_benefit_system.variables: - tables[entity][variable] = sim.calculate(variable) - - return Dataset( - id="us", - name="US", - description="A representative dataset for the US, based on the Current Population Survey.", - year=year, - model=policyengine_us_model, - data=tables, - ) diff --git a/src/policyengine/database/__init__.py b/src/policyengine/database/__init__.py deleted file mode 100644 index 69f34d89..00000000 --- a/src/policyengine/database/__init__.py +++ /dev/null @@ -1,62 +0,0 @@ -from .baseline_parameter_value_table import ( - BaselineParameterValueTable, - baseline_parameter_value_table_link, -) -from .baseline_variable_table import ( - BaselineVariableTable, - baseline_variable_table_link, -) -from .database import Database -from .dataset_table import DatasetTable, dataset_table_link -from .dynamic_table import DynamicTable, dynamic_table_link -from .link import TableLink - -# Import all table classes and links -from .model_table import ModelTable, model_table_link -from .model_version_table import ModelVersionTable, model_version_table_link -from .parameter_table import ParameterTable, parameter_table_link -from .parameter_value_table import ( - ParameterValueTable, - parameter_value_table_link, -) -from .policy_table import PolicyTable, policy_table_link -from .simulation_table import SimulationTable, simulation_table_link -from .versioned_dataset_table import ( - VersionedDatasetTable, - versioned_dataset_table_link, -) -from .aggregate import AggregateTable, aggregate_table_link -from .aggregate_change import AggregateChangeTable, aggregate_change_table_link - -__all__ = [ - "Database", - "TableLink", - # Tables - "ModelTable", - "ModelVersionTable", - "DatasetTable", - "VersionedDatasetTable", - "PolicyTable", - "DynamicTable", - "ParameterTable", - "ParameterValueTable", - "BaselineParameterValueTable", - "BaselineVariableTable", - "SimulationTable", - "AggregateTable", - "AggregateChangeTable", - # Links - "model_table_link", - "model_version_table_link", - "dataset_table_link", - "versioned_dataset_table_link", - "policy_table_link", - "dynamic_table_link", - "parameter_table_link", - "parameter_value_table_link", - "baseline_parameter_value_table_link", - "baseline_variable_table_link", - "simulation_table_link", - "aggregate_table_link", - "aggregate_change_table_link", -] diff --git a/src/policyengine/database/aggregate.py b/src/policyengine/database/aggregate.py deleted file mode 100644 index c192605a..00000000 --- a/src/policyengine/database/aggregate.py +++ /dev/null @@ -1,110 +0,0 @@ -from typing import TYPE_CHECKING -from uuid import uuid4 - -from sqlmodel import Field, SQLModel - -from policyengine.database.link import TableLink -from policyengine.models.aggregate import Aggregate -from policyengine.models import Simulation - -if TYPE_CHECKING: - from .database import Database - - -class AggregateTable(SQLModel, table=True): - __tablename__ = "aggregates" - - id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True) - simulation_id: str = Field( - foreign_key="simulations.id", ondelete="CASCADE" - ) - entity: str - variable_name: str - year: int | None = None - filter_variable_name: str | None = None - filter_variable_value: str | None = None - filter_variable_leq: float | None = None - filter_variable_geq: float | None = None - 
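The deleted `compress.py` is a two-function pickle-plus-blosc codec, and every binary column in the tables that follow (`data`, `result`, `simulation_function`, `simulation_modifier`) passes through it. A round-trip sketch, assuming the `blosc` package is installed:

```python
import pickle
from typing import Any

import blosc


def compress_data(data: Any) -> bytes:
    # zstd at level 9 with byte shuffling, as in the deleted module.
    return blosc.compress(
        pickle.dumps(data), typesize=8, cname="zstd", clevel=9, shuffle=blosc.SHUFFLE
    )


def decompress_data(blob: bytes) -> Any:
    return pickle.loads(blosc.decompress(blob))


payload = {"person": {"age": [34, 29, 61]}}
assert decompress_data(compress_data(payload)) == payload
```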
filter_variable_quantile_leq: float | None = None - filter_variable_quantile_geq: float | None = None - filter_variable_quantile_value: str | None = None - aggregate_function: str - reportelement_id: str | None = None - value: float | None = None - - @classmethod - def convert_from_model(cls, model: Aggregate, database: "Database" = None) -> "AggregateTable": - """Convert an Aggregate instance to an AggregateTable instance. - - Args: - model: The Aggregate instance to convert - database: The database instance for persisting the simulation if needed - - Returns: - An AggregateTable instance - """ - # Don't try to save the simulation here - it's already being saved - # This prevents circular references - - return cls( - id=model.id, - simulation_id=model.simulation.id if model.simulation else None, - entity=model.entity, - variable_name=model.variable_name, - year=model.year, - filter_variable_name=model.filter_variable_name, - filter_variable_value=model.filter_variable_value, - filter_variable_leq=model.filter_variable_leq, - filter_variable_geq=model.filter_variable_geq, - filter_variable_quantile_leq=model.filter_variable_quantile_leq, - filter_variable_quantile_geq=model.filter_variable_quantile_geq, - filter_variable_quantile_value=model.filter_variable_quantile_value, - aggregate_function=model.aggregate_function, - reportelement_id=model.reportelement_id, - value=model.value, - ) - - def convert_to_model(self, database: "Database" = None) -> Aggregate: - """Convert this AggregateTable instance to an Aggregate instance. - - Args: - database: The database instance for resolving the simulation foreign key - - Returns: - An Aggregate instance - """ - from .simulation_table import SimulationTable - from sqlmodel import select - - # Resolve the simulation foreign key - simulation = None - if database and self.simulation_id: - sim_table = database.session.exec( - select(SimulationTable).where(SimulationTable.id == self.simulation_id) - ).first() - if sim_table: - simulation = sim_table.convert_to_model(database) - - return Aggregate( - id=self.id, - simulation=simulation, - entity=self.entity, - variable_name=self.variable_name, - year=self.year, - filter_variable_name=self.filter_variable_name, - filter_variable_value=self.filter_variable_value, - filter_variable_leq=self.filter_variable_leq, - filter_variable_geq=self.filter_variable_geq, - filter_variable_quantile_leq=self.filter_variable_quantile_leq, - filter_variable_quantile_geq=self.filter_variable_quantile_geq, - filter_variable_quantile_value=self.filter_variable_quantile_value, - aggregate_function=self.aggregate_function, - reportelement_id=self.reportelement_id, - value=self.value, - ) - - -aggregate_table_link = TableLink( - model_cls=Aggregate, - table_cls=AggregateTable, -) diff --git a/src/policyengine/database/aggregate_change.py b/src/policyengine/database/aggregate_change.py deleted file mode 100644 index 47fbda05..00000000 --- a/src/policyengine/database/aggregate_change.py +++ /dev/null @@ -1,131 +0,0 @@ -from typing import TYPE_CHECKING -from uuid import uuid4 - -from sqlmodel import Field, SQLModel - -from policyengine.database.link import TableLink -from policyengine.models.aggregate_change import AggregateChange - -if TYPE_CHECKING: - from .database import Database - - -class AggregateChangeTable(SQLModel, table=True): - __tablename__ = "aggregate_changes" - - id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True) - baseline_simulation_id: str = Field( - foreign_key="simulations.id", 
ondelete="CASCADE" - ) - comparison_simulation_id: str = Field( - foreign_key="simulations.id", ondelete="CASCADE" - ) - entity: str - variable_name: str - year: int | None = None - filter_variable_name: str | None = None - filter_variable_value: str | None = None - filter_variable_leq: float | None = None - filter_variable_geq: float | None = None - filter_variable_quantile_leq: float | None = None - filter_variable_quantile_geq: float | None = None - filter_variable_quantile_value: str | None = None - aggregate_function: str - reportelement_id: str | None = None - - baseline_value: float | None = None - comparison_value: float | None = None - change: float | None = None - relative_change: float | None = None - - @classmethod - def convert_from_model(cls, model: AggregateChange, database: "Database" = None) -> "AggregateChangeTable": - """Convert an AggregateChange instance to an AggregateChangeTable instance. - - Args: - model: The AggregateChange instance to convert - database: The database instance for persisting the simulations if needed - - Returns: - An AggregateChangeTable instance - """ - return cls( - id=model.id, - baseline_simulation_id=model.baseline_simulation.id if model.baseline_simulation else None, - comparison_simulation_id=model.comparison_simulation.id if model.comparison_simulation else None, - entity=model.entity, - variable_name=model.variable_name, - year=model.year, - filter_variable_name=model.filter_variable_name, - filter_variable_value=model.filter_variable_value, - filter_variable_leq=model.filter_variable_leq, - filter_variable_geq=model.filter_variable_geq, - filter_variable_quantile_leq=model.filter_variable_quantile_leq, - filter_variable_quantile_geq=model.filter_variable_quantile_geq, - filter_variable_quantile_value=model.filter_variable_quantile_value, - aggregate_function=model.aggregate_function, - reportelement_id=model.reportelement_id, - baseline_value=model.baseline_value, - comparison_value=model.comparison_value, - change=model.change, - relative_change=model.relative_change, - ) - - def convert_to_model(self, database: "Database" = None) -> AggregateChange: - """Convert this AggregateChangeTable instance to an AggregateChange instance. 
- - Args: - database: The database instance for resolving simulation foreign keys - - Returns: - An AggregateChange instance - """ - from .simulation_table import SimulationTable - from sqlmodel import select - - # Resolve the simulation foreign keys - baseline_simulation = None - comparison_simulation = None - - if database: - if self.baseline_simulation_id: - sim_table = database.session.exec( - select(SimulationTable).where(SimulationTable.id == self.baseline_simulation_id) - ).first() - if sim_table: - baseline_simulation = sim_table.convert_to_model(database) - - if self.comparison_simulation_id: - sim_table = database.session.exec( - select(SimulationTable).where(SimulationTable.id == self.comparison_simulation_id) - ).first() - if sim_table: - comparison_simulation = sim_table.convert_to_model(database) - - return AggregateChange( - id=self.id, - baseline_simulation=baseline_simulation, - comparison_simulation=comparison_simulation, - entity=self.entity, - variable_name=self.variable_name, - year=self.year, - filter_variable_name=self.filter_variable_name, - filter_variable_value=self.filter_variable_value, - filter_variable_leq=self.filter_variable_leq, - filter_variable_geq=self.filter_variable_geq, - filter_variable_quantile_leq=self.filter_variable_quantile_leq, - filter_variable_quantile_geq=self.filter_variable_quantile_geq, - filter_variable_quantile_value=self.filter_variable_quantile_value, - aggregate_function=self.aggregate_function, - reportelement_id=self.reportelement_id, - baseline_value=self.baseline_value, - comparison_value=self.comparison_value, - change=self.change, - relative_change=self.relative_change, - ) - - -aggregate_change_table_link = TableLink( - model_cls=AggregateChange, - table_cls=AggregateChangeTable, -) \ No newline at end of file diff --git a/src/policyengine/database/baseline_parameter_value_table.py b/src/policyengine/database/baseline_parameter_value_table.py deleted file mode 100644 index 6485223c..00000000 --- a/src/policyengine/database/baseline_parameter_value_table.py +++ /dev/null @@ -1,112 +0,0 @@ -from datetime import datetime -from typing import Any -from uuid import uuid4 - -from sqlmodel import JSON, Column, Field, SQLModel -from typing import TYPE_CHECKING - -from policyengine.models import ModelVersion, Parameter, BaselineParameterValue - -from .link import TableLink - -if TYPE_CHECKING: - from .database import Database - - -class BaselineParameterValueTable(SQLModel, table=True): - __tablename__ = "baseline_parameter_values" - __table_args__ = ({"extend_existing": True},) - - id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True) - parameter_id: str = Field(nullable=False) # Part of composite foreign key - model_id: str = Field(nullable=False) # Part of composite foreign key - model_version_id: str = Field( - foreign_key="model_versions.id", ondelete="CASCADE" - ) - value: Any | None = Field( - default=None, sa_column=Column(JSON) - ) # JSON field for any type - start_date: datetime = Field(nullable=False) - end_date: datetime | None = Field(default=None) - - @classmethod - def convert_from_model(cls, model: BaselineParameterValue, database: "Database" = None) -> "BaselineParameterValueTable": - """Convert a BaselineParameterValue instance to a BaselineParameterValueTable instance.""" - import math - - # Ensure foreign objects are persisted if database is provided - if database: - if model.parameter: - database.set(model.parameter, commit=False) - if model.model_version: - database.set(model.model_version, 
commit=False) - - # Handle special float values - value = model.value - if isinstance(value, float): - if math.isinf(value): - value = "Infinity" if value > 0 else "-Infinity" - elif math.isnan(value): - value = "NaN" - - return cls( - id=model.id, - parameter_id=model.parameter.id if model.parameter else None, - model_id=model.parameter.model.id if model.parameter and model.parameter.model else None, - model_version_id=model.model_version.id if model.model_version else None, - value=value, - start_date=model.start_date, - end_date=model.end_date, - ) - - def convert_to_model(self, database: "Database" = None) -> BaselineParameterValue: - """Convert this BaselineParameterValueTable instance to a BaselineParameterValue instance.""" - from .parameter_table import ParameterTable - from .model_version_table import ModelVersionTable - from sqlmodel import select - - # Resolve foreign keys - parameter = None - model_version = None - - if database: - if self.parameter_id and self.model_id: - param_table = database.session.exec( - select(ParameterTable).where( - ParameterTable.id == self.parameter_id, - ParameterTable.model_id == self.model_id - ) - ).first() - if param_table: - parameter = param_table.convert_to_model(database) - - if self.model_version_id: - version_table = database.session.exec( - select(ModelVersionTable).where(ModelVersionTable.id == self.model_version_id) - ).first() - if version_table: - model_version = version_table.convert_to_model(database) - - # Handle special string values - value = self.value - if value == "Infinity": - value = float("inf") - elif value == "-Infinity": - value = float("-inf") - elif value == "NaN": - value = float("nan") - - return BaselineParameterValue( - id=self.id, - parameter=parameter, - model_version=model_version, - value=value, - start_date=self.start_date, - end_date=self.end_date, - ) - - -baseline_parameter_value_table_link = TableLink( - model_cls=BaselineParameterValue, - table_cls=BaselineParameterValueTable, -) diff --git a/src/policyengine/database/baseline_variable_table.py b/src/policyengine/database/baseline_variable_table.py deleted file mode 100644 index e7773c80..00000000 --- a/src/policyengine/database/baseline_variable_table.py +++ /dev/null @@ -1,81 +0,0 @@ -from sqlmodel import Field, SQLModel -from typing import TYPE_CHECKING - -from policyengine.models import ModelVersion, BaselineVariable - -from .link import TableLink - -if TYPE_CHECKING: - from .database import Database - - -class BaselineVariableTable(SQLModel, table=True): - __tablename__ = "baseline_variables" - __table_args__ = ({"extend_existing": True},) - - id: str = Field(primary_key=True) # Variable name - model_id: str = Field( - primary_key=True, foreign_key="models.id" - ) # Part of composite key - model_version_id: str = Field( - foreign_key="model_versions.id", ondelete="CASCADE" - ) - entity: str = Field(nullable=False) - label: str | None = Field(default=None) - description: str | None = Field(default=None) - data_type: str | None = Field(default=None) # Data type name - - @classmethod - def convert_from_model(cls, model: BaselineVariable, database: "Database" = None) -> "BaselineVariableTable": - """Convert a BaselineVariable instance to a BaselineVariableTable instance.""" - # Ensure foreign objects are persisted if database is provided - if database and model.model_version: - database.set(model.model_version, commit=False) - - return cls( - id=model.id, - model_id=model.model_version.model.id if model.model_version and model.model_version.model else 
None, - model_version_id=model.model_version.id if model.model_version else None, - entity=model.entity, - label=model.label, - description=model.description, - data_type=model.data_type.__name__ if model.data_type else None, - ) - - def convert_to_model(self, database: "Database" = None) -> BaselineVariable: - """Convert this BaselineVariableTable instance to a BaselineVariable instance.""" - from .model_version_table import ModelVersionTable - from sqlmodel import select - - # Resolve foreign keys - model_version = None - - if database and self.model_version_id: - version_table = database.session.exec( - select(ModelVersionTable).where(ModelVersionTable.id == self.model_version_id) - ).first() - if version_table: - model_version = version_table.convert_to_model(database) - - # Convert data_type string back to type - data_type = None - if self.data_type: - try: - data_type = eval(self.data_type) - except: - data_type = None - - return BaselineVariable( - id=self.id, - model_version=model_version, - entity=self.entity, - label=self.label, - description=self.description, - data_type=data_type, - ) - - -baseline_variable_table_link = TableLink( - model_cls=BaselineVariable, - table_cls=BaselineVariableTable, -) diff --git a/src/policyengine/database/database.py b/src/policyengine/database/database.py deleted file mode 100644 index 2eb0fc40..00000000 --- a/src/policyengine/database/database.py +++ /dev/null @@ -1,339 +0,0 @@ -from typing import Any - -from sqlmodel import Session, SQLModel - -from .aggregate import aggregate_table_link -from .baseline_parameter_value_table import baseline_parameter_value_table_link -from .baseline_variable_table import baseline_variable_table_link -from .dataset_table import dataset_table_link -from .dynamic_table import dynamic_table_link -from .link import TableLink - -# Import all table links -from .model_table import model_table_link -from .model_version_table import model_version_table_link -from .parameter_table import parameter_table_link -from .parameter_value_table import parameter_value_table_link -from .policy_table import policy_table_link -from .simulation_table import simulation_table_link -from .versioned_dataset_table import versioned_dataset_table_link - - -class Database: - url: str - _model_table_links: list[TableLink] - - def __init__(self, url: str): - self.url = url - self.engine = self._create_engine() - self.session = Session(self.engine) - - # Initialize instance variable for table links - self._model_table_links = [] - - # Register all table links - for link in [ - model_table_link, - model_version_table_link, - dataset_table_link, - versioned_dataset_table_link, - policy_table_link, - dynamic_table_link, - parameter_table_link, - parameter_value_table_link, - baseline_parameter_value_table_link, - baseline_variable_table_link, - simulation_table_link, - aggregate_table_link, - ]: - self.register_table(link) - - def _create_engine(self): - from sqlmodel import create_engine - - # Configure engine with proper settings for PostgreSQL/Supabase - engine_args = { - "echo": False, - "pool_pre_ping": True, # Verify connections before using - "pool_recycle": 3600, # Recycle connections after 1 hour - } - - # For PostgreSQL, ensure proper connection pooling - if self.url.startswith("postgresql"): - engine_args["pool_size"] = 5 - engine_args["max_overflow"] = 10 - - return create_engine(self.url, **engine_args) - - def create_tables(self): - """Create all database tables.""" - SQLModel.metadata.create_all(self.engine) - - def 
drop_tables(self): - """Drop all database tables.""" - SQLModel.metadata.drop_all(self.engine) - - def reset(self): - """Drop and recreate all tables.""" - self.drop_tables() - self.create_tables() - - def ensure_anonymous_user(self): - """Deprecated: This method no longer exists as user management has been moved to the API layer.""" - pass - - def __enter__(self): - """Context manager entry - creates a session.""" - self.session = Session(self.engine) - return self.session - - def __exit__(self, exc_type, exc_val, exc_tb): - """Context manager exit - closes the session.""" - if exc_type: - self.session.rollback() - else: - self.session.commit() - self.session.close() - - def register_table(self, link: TableLink): - """Register a table link for use with the database. - - Note: This does NOT create the table. Call create_tables() after - registering all tables to create them in the correct order respecting - foreign key dependencies. - - Args: - link: The TableLink to register - """ - self._model_table_links.append(link) - - def verify_tables_exist(self) -> dict[str, bool]: - """Verify that all registered tables exist in the database. - - Returns: - A dictionary mapping table names to whether they exist - """ - from sqlalchemy import inspect as sql_inspect - - inspector = sql_inspect(self.engine) - existing_tables = set(inspector.get_table_names()) - - results = {} - for link in self._model_table_links: - table_name = link.table_cls.__tablename__ - results[table_name] = table_name in existing_tables - - return results - - def get(self, model_cls: type, **kwargs): - """Get a model instance from the database by its attributes.""" - from sqlmodel import select - - # Find the table class for this model - table_link = next( - ( - link - for link in self._model_table_links - if link.model_cls == model_cls - ), - None, - ) - - if table_link is None: - return None - - # Query the database - statement = select(table_link.table_cls).filter_by(**kwargs) - result = self.session.exec(statement).first() - - if result is None: - return None - - # Use the table's convert_to_model method - return result.convert_to_model(self) - - def set(self, object: Any, commit: bool = True): - """Save or update a model instance in the database.""" - from sqlmodel import select - from sqlalchemy.inspection import inspect - - # Find the table class for this model - table_link = next( - ( - link - for link in self._model_table_links - if link.model_cls is type(object) - ), - None, - ) - - if table_link is None: - return - - # Convert model to table instance - table_obj = table_link.table_cls.convert_from_model(object, self) - - # Get primary key columns - mapper = inspect(table_link.table_cls) - pk_cols = [col.name for col in mapper.primary_key] - - # Build query to check if exists - query = select(table_link.table_cls) - for pk_col in pk_cols: - query = query.where( - getattr(table_link.table_cls, pk_col) == getattr(table_obj, pk_col) - ) - - existing = self.session.exec(query).first() - - if existing: - # Update existing record - for key, value in table_obj.model_dump().items(): - setattr(existing, key, value) - self.session.add(existing) - else: - self.session.add(table_obj) - - if commit: - self.session.commit() - - def register_model_version(self, model_version): - """Register a model version with its model and seed objects. 
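`_create_engine` above layers PostgreSQL-specific pool sizing on top of defaults that apply to every backend. The same conditional construction in isolation; the URL is the local Supabase default used in the quickstart, and `create_engine` connects lazily, so this runs without a live server:

```python
from sqlmodel import create_engine

url = "postgresql://postgres:postgres@127.0.0.1:54322/postgres"

engine_args = {
    "echo": False,
    "pool_pre_ping": True,  # verify connections before handing them out
    "pool_recycle": 3600,   # recycle connections after an hour
}
if url.startswith("postgresql"):
    engine_args["pool_size"] = 5
    engine_args["max_overflow"] = 10

engine = create_engine(url, **engine_args)
```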
- This replaces all existing parameters, baseline parameter values, - and baseline variables for this model version.""" - # Add or update the model directly to avoid conflicts - from policyengine.utils.compress import compress_data - - from .baseline_parameter_value_table import BaselineParameterValueTable - from .baseline_variable_table import BaselineVariableTable - from .model_table import ModelTable - from .model_version_table import ModelVersionTable - from .parameter_table import ParameterTable - - existing_model = ( - self.session.query(ModelTable) - .filter(ModelTable.id == model_version.model.id) - .first() - ) - if not existing_model: - model_table = ModelTable( - id=model_version.model.id, - name=model_version.model.name, - description=model_version.model.description, - simulation_function=compress_data( - model_version.model.simulation_function - ), - ) - self.session.add(model_table) - self.session.flush() - - # Add or update the model version - existing_version = ( - self.session.query(ModelVersionTable) - .filter(ModelVersionTable.id == model_version.id) - .first() - ) - if not existing_version: - version_table = ModelVersionTable( - id=model_version.id, - model_id=model_version.model.id, - version=model_version.version, - description=model_version.description, - created_at=model_version.created_at, - ) - self.session.add(version_table) - self.session.flush() - - # Get seed objects from the model - seed_objects = model_version.model.create_seed_objects(model_version) - - # Delete ALL existing seed data for this model (not just this version) - # This ensures we start fresh with the new version's data - # Order matters due to foreign key constraints - - # First delete baseline parameter values (they reference parameters) - self.session.query(BaselineParameterValueTable).filter( - BaselineParameterValueTable.model_id == model_version.model.id - ).delete() - - # Then delete baseline variables for this model - self.session.query(BaselineVariableTable).filter( - BaselineVariableTable.model_id == model_version.model.id - ).delete() - - # Finally delete all parameters for this model - self.session.query(ParameterTable).filter( - ParameterTable.model_id == model_version.model.id - ).delete() - - self.session.commit() - - # Add all parameters first - for parameter in seed_objects.parameters: - # We need to add directly to session to avoid the autoflush issue - from .parameter_table import ParameterTable - - param_table = ParameterTable( - id=parameter.id, - model_id=parameter.model.id, # Now required as part of composite key - description=parameter.description, - data_type=parameter.data_type.__name__ - if parameter.data_type - else None, - label=parameter.label, - unit=parameter.unit, - ) - self.session.add(param_table) - - # Flush parameters to database so they exist for foreign key constraints - self.session.flush() - - # Add all baseline parameter values - for baseline_param_value in seed_objects.baseline_parameter_values: - import math - from uuid import uuid4 - - from .baseline_parameter_value_table import ( - BaselineParameterValueTable, - ) - - # Handle special float values that JSON doesn't support - value = baseline_param_value.value - if isinstance(value, float): - if math.isinf(value): - value = "Infinity" if value > 0 else "-Infinity" - elif math.isnan(value): - value = "NaN" - - bpv_table = BaselineParameterValueTable( - id=str(uuid4()), - parameter_id=baseline_param_value.parameter.id, - model_id=baseline_param_value.parameter.model.id, # Add model_id - 
model_version_id=baseline_param_value.model_version.id, - value=value, - start_date=baseline_param_value.start_date, - end_date=baseline_param_value.end_date, - ) - self.session.add(bpv_table) - - # Add all baseline variables - for baseline_variable in seed_objects.baseline_variables: - from .baseline_variable_table import BaselineVariableTable - - bv_table = BaselineVariableTable( - id=baseline_variable.id, - model_id=baseline_variable.model_version.model.id, # Add model_id - model_version_id=baseline_variable.model_version.id, - entity=baseline_variable.entity, - label=baseline_variable.label, - description=baseline_variable.description, - data_type=(lambda bv: compress_data(bv.data_type))( - baseline_variable - ) - if baseline_variable.data_type - else None, - ) - self.session.add(bv_table) - - # Commit everything at once - self.session.commit() diff --git a/src/policyengine/database/dataset_table.py b/src/policyengine/database/dataset_table.py deleted file mode 100644 index cf22cda8..00000000 --- a/src/policyengine/database/dataset_table.py +++ /dev/null @@ -1,94 +0,0 @@ -from typing import TYPE_CHECKING -from uuid import uuid4 - -from sqlmodel import Field, SQLModel - -from policyengine.models import Dataset, Model, VersionedDataset -from policyengine.utils.compress import compress_data, decompress_data - -from .link import TableLink - -if TYPE_CHECKING: - from .database import Database - - -class DatasetTable(SQLModel, table=True): - __tablename__ = "datasets" - - id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True) - name: str = Field(nullable=False) - description: str | None = Field(default=None) - version: str | None = Field(default=None) - versioned_dataset_id: str | None = Field( - default=None, foreign_key="versioned_datasets.id", ondelete="SET NULL" - ) - year: int | None = Field(default=None) - data: bytes | None = Field(default=None) - model_id: str | None = Field( - default=None, foreign_key="models.id", ondelete="SET NULL" - ) - - @classmethod - def convert_from_model(cls, model: Dataset, database: "Database" = None) -> "DatasetTable": - """Convert a Dataset instance to a DatasetTable instance. - - Args: - model: The Dataset instance to convert - database: The database instance for persisting foreign objects if needed - - Returns: - A DatasetTable instance - """ - # Ensure foreign objects are persisted if database is provided - if database: - if model.versioned_dataset: - database.set(model.versioned_dataset, commit=False) - if model.model: - database.set(model.model, commit=False) - - return cls( - id=model.id, - name=model.name, - description=model.description, - version=model.version, - versioned_dataset_id=model.versioned_dataset.id if model.versioned_dataset else None, - year=model.year, - data=compress_data(model.data) if model.data else None, - model_id=model.model.id if model.model else None, - ) - - def convert_to_model(self, database: "Database" = None) -> Dataset: - """Convert this DatasetTable instance to a Dataset instance. 
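`DatasetTable.data` stores the whole dict of entity tables as a single compressed blob rather than as relational rows. A sketch of the blob-column round trip, with `zlib` standing in for blosc so the example needs no extra dependency:

```python
import pickle
import zlib

from sqlmodel import Field, Session, SQLModel, create_engine, select


class DatasetBlob(SQLModel, table=True):  # stand-in for DatasetTable
    id: str = Field(primary_key=True)
    data: bytes | None = None


engine = create_engine("sqlite://")
SQLModel.metadata.create_all(engine)

tables = {"person": {"age": [34, 29]}, "household": {"region": ["LONDON"]}}
with Session(engine) as session:
    session.add(DatasetBlob(id="uk-2029", data=zlib.compress(pickle.dumps(tables))))
    session.commit()
    row = session.exec(select(DatasetBlob)).one()
    assert pickle.loads(zlib.decompress(row.data)) == tables
```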
- - Args: - database: The database instance for resolving foreign keys - - Returns: - A Dataset instance - """ - # Resolve foreign keys - versioned_dataset = None - model = None - - if database: - if self.versioned_dataset_id: - versioned_dataset = database.get(VersionedDataset, id=self.versioned_dataset_id) - if self.model_id: - model = database.get(Model, id=self.model_id) - - return Dataset( - id=self.id, - name=self.name, - description=self.description, - version=self.version, - versioned_dataset=versioned_dataset, - year=self.year, - data=decompress_data(self.data) if self.data else None, - model=model, - ) - - -dataset_table_link = TableLink( - model_cls=Dataset, - table_cls=DatasetTable, -) diff --git a/src/policyengine/database/dynamic_table.py b/src/policyengine/database/dynamic_table.py deleted file mode 100644 index 086e6bd9..00000000 --- a/src/policyengine/database/dynamic_table.py +++ /dev/null @@ -1,68 +0,0 @@ -from datetime import datetime -from typing import TYPE_CHECKING -from uuid import uuid4 - -from sqlmodel import Field, SQLModel - -from policyengine.models import Dynamic -from policyengine.utils.compress import compress_data, decompress_data - -from .link import TableLink - -if TYPE_CHECKING: - from .database import Database - - -class DynamicTable(SQLModel, table=True): - __tablename__ = "dynamics" - - id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True) - name: str = Field(nullable=False) - description: str | None = Field(default=None) - simulation_modifier: bytes | None = Field(default=None) - created_at: datetime = Field(default_factory=datetime.now) - updated_at: datetime = Field(default_factory=datetime.now) - - @classmethod - def convert_from_model(cls, model: Dynamic, database: "Database" = None) -> "DynamicTable": - """Convert a Dynamic instance to a DynamicTable instance. - - Args: - model: The Dynamic instance to convert - database: The database instance (not used for this table) - - Returns: - A DynamicTable instance - """ - return cls( - id=model.id, - name=model.name, - description=model.description, - simulation_modifier=compress_data(model.simulation_modifier) if model.simulation_modifier else None, - created_at=model.created_at, - updated_at=model.updated_at, - ) - - def convert_to_model(self, database: "Database" = None) -> Dynamic: - """Convert this DynamicTable instance to a Dynamic instance. 
- - Args: - database: The database instance (not used for this table) - - Returns: - A Dynamic instance - """ - return Dynamic( - id=self.id, - name=self.name, - description=self.description, - simulation_modifier=decompress_data(self.simulation_modifier) if self.simulation_modifier else None, - created_at=self.created_at, - updated_at=self.updated_at, - ) - - -dynamic_table_link = TableLink( - model_cls=Dynamic, - table_cls=DynamicTable, -) diff --git a/src/policyengine/database/link.py b/src/policyengine/database/link.py deleted file mode 100644 index 2bb1a041..00000000 --- a/src/policyengine/database/link.py +++ /dev/null @@ -1,8 +0,0 @@ -from pydantic import BaseModel -from sqlmodel import SQLModel - - -class TableLink(BaseModel): - """Simple registry mapping model classes to table classes.""" - model_cls: type[BaseModel] - table_cls: type[SQLModel] diff --git a/src/policyengine/database/model_table.py b/src/policyengine/database/model_table.py deleted file mode 100644 index 220238c8..00000000 --- a/src/policyengine/database/model_table.py +++ /dev/null @@ -1,60 +0,0 @@ -from typing import TYPE_CHECKING - -from sqlmodel import Field, SQLModel - -from policyengine.models import Model -from policyengine.utils.compress import compress_data, decompress_data - -from .link import TableLink - -if TYPE_CHECKING: - from .database import Database - - -class ModelTable(SQLModel, table=True, extend_existing=True): - __tablename__ = "models" - - id: str = Field(primary_key=True) - name: str = Field(nullable=False) - description: str | None = Field(default=None) - simulation_function: bytes - - @classmethod - def convert_from_model(cls, model: Model, database: "Database" = None) -> "ModelTable": - """Convert a Model instance to a ModelTable instance. - - Args: - model: The Model instance to convert - database: The database instance (not used for this table) - - Returns: - A ModelTable instance - """ - return cls( - id=model.id, - name=model.name, - description=model.description, - simulation_function=compress_data(model.simulation_function), - ) - - def convert_to_model(self, database: "Database" = None) -> Model: - """Convert this ModelTable instance to a Model instance. 
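`ModelTable.simulation_function` and `DynamicTable.simulation_modifier` persist callables by pickling them, which only works for functions importable by qualified name; that is presumably why the model modules define named, top-level functions rather than lambdas. A sketch of the constraint:

```python
import pickle


def simulation_modifier(sim):
    # A named, module-level function pickles by reference to its qualified name.
    return sim


blob = pickle.dumps(simulation_modifier)
assert pickle.loads(blob) is simulation_modifier

try:
    pickle.dumps(lambda sim: sim)  # lambdas have no importable name
except (pickle.PicklingError, AttributeError) as err:
    print(f"cannot persist a lambda: {err}")
```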
- - Args: - database: The database instance (not used for this table) - - Returns: - A Model instance - """ - return Model( - id=self.id, - name=self.name, - description=self.description, - simulation_function=decompress_data(self.simulation_function), - ) - - -model_table_link = TableLink( - model_cls=Model, - table_cls=ModelTable, -) diff --git a/src/policyengine/database/model_version_table.py b/src/policyengine/database/model_version_table.py deleted file mode 100644 index 86d19fed..00000000 --- a/src/policyengine/database/model_version_table.py +++ /dev/null @@ -1,73 +0,0 @@ -from datetime import datetime -from typing import TYPE_CHECKING -from uuid import uuid4 - -from sqlmodel import Field, SQLModel - -from policyengine.models import Model, ModelVersion - -from .link import TableLink - -if TYPE_CHECKING: - from .database import Database - - -class ModelVersionTable(SQLModel, table=True): - __tablename__ = "model_versions" - - id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True) - model_id: str = Field(foreign_key="models.id", ondelete="CASCADE") - version: str = Field(nullable=False) - description: str | None = Field(default=None) - created_at: datetime = Field(default_factory=datetime.now) - - @classmethod - def convert_from_model(cls, model: ModelVersion, database: "Database" = None) -> "ModelVersionTable": - """Convert a ModelVersion instance to a ModelVersionTable instance. - - Args: - model: The ModelVersion instance to convert - database: The database instance for persisting the model if needed - - Returns: - A ModelVersionTable instance - """ - # Ensure the Model is persisted if database is provided - if database and model.model: - database.set(model.model, commit=False) - - return cls( - id=model.id, - model_id=model.model.id if model.model else None, - version=model.version, - description=model.description, - created_at=model.created_at, - ) - - def convert_to_model(self, database: "Database" = None) -> ModelVersion: - """Convert this ModelVersionTable instance to a ModelVersion instance. 
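Each table module above ends by declaring a `TableLink`, and `Database.get`/`Database.set` dispatch by scanning the registered links for a matching model class. The registry pattern in isolation (the domain and table classes here are toy stand-ins):

```python
from pydantic import BaseModel
from sqlmodel import SQLModel


class TableLink(BaseModel):
    model_cls: type[BaseModel]
    table_cls: type[SQLModel]


class ModelVersion(BaseModel):  # toy domain model
    id: str


class ModelVersionRow(SQLModel):  # toy table class; table=True omitted
    id: str


links = [TableLink(model_cls=ModelVersion, table_cls=ModelVersionRow)]


def find_link(model_cls: type) -> TableLink | None:
    # Same dispatch as Database.get/set: first link whose model class matches.
    return next((link for link in links if link.model_cls is model_cls), None)


assert find_link(ModelVersion).table_cls is ModelVersionRow
```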
- - Args: - database: The database instance for resolving the model foreign key - - Returns: - A ModelVersion instance - """ - # Resolve the model foreign key - model = None - if database and self.model_id: - model = database.get(Model, id=self.model_id) - - return ModelVersion( - id=self.id, - model=model, - version=self.version, - description=self.description, - created_at=self.created_at, - ) - - -model_version_table_link = TableLink( - model_cls=ModelVersion, - table_cls=ModelVersionTable, -) diff --git a/src/policyengine/database/parameter_table.py b/src/policyengine/database/parameter_table.py deleted file mode 100644 index aef88e5a..00000000 --- a/src/policyengine/database/parameter_table.py +++ /dev/null @@ -1,92 +0,0 @@ -from typing import TYPE_CHECKING - -from sqlmodel import Field, SQLModel - -from policyengine.models import Model, Parameter - -from .link import TableLink - -if TYPE_CHECKING: - from .database import Database - - -class ParameterTable(SQLModel, table=True): - __tablename__ = "parameters" - __table_args__ = ({"extend_existing": True},) - - id: str = Field(primary_key=True) # Parameter name - model_id: str = Field( - primary_key=True, foreign_key="models.id" - ) # Part of composite key - description: str | None = Field(default=None) - data_type: str | None = Field(nullable=True) # Data type name - label: str | None = Field(default=None) - unit: str | None = Field(default=None) - - @classmethod - def convert_from_model(cls, model: Parameter, database: "Database" = None) -> "ParameterTable": - """Convert a Parameter instance to a ParameterTable instance. - - Args: - model: The Parameter instance to convert - database: The database instance for persisting the model if needed - - Returns: - A ParameterTable instance - """ - # Ensure the Model is persisted if database is provided - if database and model.model: - database.set(model.model, commit=False) - - return cls( - id=model.id, - model_id=model.model.id if model.model else None, - description=model.description, - data_type=model.data_type.__name__ if model.data_type else None, - label=model.label, - unit=model.unit, - ) - - def convert_to_model(self, database: "Database" = None) -> Parameter: - """Convert this ParameterTable instance to a Parameter instance. 
- - Args: - database: The database instance for resolving the model foreign key - - Returns: - A Parameter instance - """ - from .model_table import ModelTable - from sqlmodel import select - - # Resolve the model foreign key - model = None - if database and self.model_id: - model_table = database.session.exec( - select(ModelTable).where(ModelTable.id == self.model_id) - ).first() - if model_table: - model = model_table.convert_to_model(database) - - # Convert data_type string back to type - data_type = None - if self.data_type: - try: - data_type = eval(self.data_type) - except: - data_type = None - - return Parameter( - id=self.id, - description=self.description, - data_type=data_type, - model=model, - label=self.label, - unit=self.unit, - ) - - -parameter_table_link = TableLink( - model_cls=Parameter, - table_cls=ParameterTable, -) diff --git a/src/policyengine/database/parameter_value_table.py b/src/policyengine/database/parameter_value_table.py deleted file mode 100644 index 6bfd60dd..00000000 --- a/src/policyengine/database/parameter_value_table.py +++ /dev/null @@ -1,107 +0,0 @@ -from datetime import datetime -from typing import TYPE_CHECKING, Any -from uuid import uuid4 - -from sqlmodel import JSON, Column, Field, SQLModel - -from policyengine.models import Parameter, ParameterValue - -from .link import TableLink - -if TYPE_CHECKING: - from .database import Database - - -class ParameterValueTable(SQLModel, table=True): - __tablename__ = "parameter_values" - __table_args__ = ({"extend_existing": True},) - - id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True) - parameter_id: str = Field(nullable=False) # Part of composite foreign key - model_id: str = Field(nullable=False) # Part of composite foreign key - policy_id: str | None = Field(default=None, foreign_key="policies.id", ondelete="CASCADE") # Link to policy - value: Any | None = Field( - default=None, sa_column=Column(JSON) - ) # JSON field for any type - start_date: datetime = Field(nullable=False) - end_date: datetime | None = Field(default=None) - - @classmethod - def convert_from_model(cls, model: ParameterValue, database: "Database" = None) -> "ParameterValueTable": - """Convert a ParameterValue instance to a ParameterValueTable instance. - - Args: - model: The ParameterValue instance to convert - database: The database instance for persisting the parameter if needed - - Returns: - A ParameterValueTable instance - """ - import math - - # Ensure the Parameter is persisted if database is provided - if database and model.parameter: - database.set(model.parameter, commit=False) - - # Handle special float values - value = model.value - if isinstance(value, float): - if math.isinf(value): - value = "Infinity" if value > 0 else "-Infinity" - elif math.isnan(value): - value = "NaN" - - return cls( - id=model.id, - parameter_id=model.parameter.id if model.parameter else None, - model_id=model.parameter.model.id if model.parameter and model.parameter.model else None, - value=value, - start_date=model.start_date, - end_date=model.end_date, - ) - - def convert_to_model(self, database: "Database" = None) -> ParameterValue: - """Convert this ParameterValueTable instance to a ParameterValue instance. 
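Parameter values live in a JSON column, and JSON has no literals for infinities or NaN, so both parameter-value tables shuttle those floats through sentinel strings. The encode/decode pair in isolation:

```python
import math


def encode_value(value):
    if isinstance(value, float):
        if math.isinf(value):
            return "Infinity" if value > 0 else "-Infinity"
        if math.isnan(value):
            return "NaN"
    return value


def decode_value(value):
    # Caveat, as in the original: a genuine string "Infinity" would decode
    # to a float, since the sentinel is indistinguishable from real data.
    specials = {"Infinity": math.inf, "-Infinity": -math.inf, "NaN": math.nan}
    if isinstance(value, str) and value in specials:
        return specials[value]
    return value


assert decode_value(encode_value(math.inf)) == math.inf
assert math.isnan(decode_value(encode_value(math.nan)))
```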
- - Args: - database: The database instance for resolving the parameter foreign key - - Returns: - A ParameterValue instance - """ - from .parameter_table import ParameterTable - from sqlmodel import select - - # Resolve the parameter foreign key - parameter = None - if database and self.parameter_id and self.model_id: - param_table = database.session.exec( - select(ParameterTable).where( - ParameterTable.id == self.parameter_id, - ParameterTable.model_id == self.model_id - ) - ).first() - parameter = param_table.convert_to_model(database) - - # Handle special string values - value = self.value - if value == "Infinity": - value = float("inf") - elif value == "-Infinity": - value = float("-inf") - elif value == "NaN": - value = float("nan") - - return ParameterValue( - id=self.id, - parameter=parameter, - value=value, - start_date=self.start_date, - end_date=self.end_date, - ) - - -parameter_value_table_link = TableLink( - model_cls=ParameterValue, - table_cls=ParameterValueTable, -) diff --git a/src/policyengine/database/policy_table.py b/src/policyengine/database/policy_table.py deleted file mode 100644 index 0ae381e4..00000000 --- a/src/policyengine/database/policy_table.py +++ /dev/null @@ -1,136 +0,0 @@ -from datetime import datetime -from typing import TYPE_CHECKING -from uuid import uuid4 - -from sqlmodel import Field, SQLModel - -from policyengine.models import Policy -from policyengine.utils.compress import compress_data, decompress_data - -from .link import TableLink - -if TYPE_CHECKING: - from .database import Database - - -class PolicyTable(SQLModel, table=True): - __tablename__ = "policies" - - id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True) - name: str = Field(nullable=False) - description: str | None = Field(default=None) - simulation_modifier: bytes | None = Field(default=None) - created_at: datetime = Field(default_factory=datetime.now) - updated_at: datetime = Field(default_factory=datetime.now) - - @classmethod - def convert_from_model(cls, model: Policy, database: "Database" = None) -> "PolicyTable": - """Convert a Policy instance to a PolicyTable instance. 
- - Args: - model: The Policy instance to convert - database: The database instance for persisting nested objects - - Returns: - A PolicyTable instance - """ - policy_table = cls( - id=model.id, - name=model.name, - description=model.description, - simulation_modifier=compress_data(model.simulation_modifier) if model.simulation_modifier else None, - created_at=model.created_at, - updated_at=model.updated_at, - ) - - # Handle nested parameter values if database is provided - if database and model.parameter_values: - from .parameter_value_table import ParameterValueTable - from sqlmodel import select - - # First ensure the policy table is saved to the database - # This is necessary so the foreign key constraint is satisfied - # Check if it already exists - existing_policy = database.session.exec( - select(PolicyTable).where(PolicyTable.id == model.id) - ).first() - - if not existing_policy: - database.session.add(policy_table) - database.session.flush() - - # Track which parameter value IDs we want to keep - desired_pv_ids = {pv.id for pv in model.parameter_values} - - # Delete only parameter values linked to this policy that are NOT in the new list - existing_pvs = database.session.exec( - select(ParameterValueTable).where(ParameterValueTable.policy_id == model.id) - ).all() - for pv in existing_pvs: - if pv.id not in desired_pv_ids: - database.session.delete(pv) - - # Now save/update the parameter values - for param_value in model.parameter_values: - # Check if this parameter value already exists in the database - existing_pv = database.session.exec( - select(ParameterValueTable).where(ParameterValueTable.id == param_value.id) - ).first() - - if existing_pv: - # Update existing parameter value - pv_table = ParameterValueTable.convert_from_model(param_value, database) - existing_pv.parameter_id = pv_table.parameter_id - existing_pv.model_id = pv_table.model_id - existing_pv.policy_id = model.id - existing_pv.value = pv_table.value - existing_pv.start_date = pv_table.start_date - existing_pv.end_date = pv_table.end_date - else: - # Create new parameter value - pv_table = ParameterValueTable.convert_from_model(param_value, database) - pv_table.policy_id = model.id # Link to this policy - database.session.add(pv_table) - database.session.flush() - - return policy_table - - def convert_to_model(self, database: "Database" = None) -> Policy: - """Convert this PolicyTable instance to a Policy instance. 
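The nested sync in `convert_from_model` above is a reconcile: delete child rows whose ids are no longer wanted, update those that remain, insert the rest. The same shape with plain dicts standing in for the session:

```python
existing = {"pv-1": {"value": 0.20}, "pv-2": {"value": 0.40}}  # rows in the DB
desired = {"pv-2": {"value": 0.45}, "pv-3": {"value": 0.50}}   # model.parameter_values

for pv_id in list(existing):
    if pv_id not in desired:
        del existing[pv_id]             # session.delete(pv) in the real code

for pv_id, fields in desired.items():
    if pv_id in existing:
        existing[pv_id].update(fields)  # update the row in place
    else:
        existing[pv_id] = fields        # session.add(new_row)

assert existing == desired
```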
- - Args: - database: The database instance for loading nested objects - - Returns: - A Policy instance - """ - # Load nested parameter values if database is provided - parameter_values = [] - if database: - from .parameter_value_table import ParameterValueTable - from sqlmodel import select - - # Query for all parameter values linked to this policy - pv_tables = database.session.exec( - select(ParameterValueTable).where(ParameterValueTable.policy_id == self.id) - ).all() - - # Convert each one to a model - for pv_table in pv_tables: - parameter_values.append(pv_table.convert_to_model(database)) - - return Policy( - id=self.id, - name=self.name, - description=self.description, - parameter_values=parameter_values, - simulation_modifier=decompress_data(self.simulation_modifier) if self.simulation_modifier else None, - created_at=self.created_at, - updated_at=self.updated_at, - ) - - -policy_table_link = TableLink( - model_cls=Policy, - table_cls=PolicyTable, -) diff --git a/src/policyengine/database/simulation_table.py b/src/policyengine/database/simulation_table.py deleted file mode 100644 index de6eae58..00000000 --- a/src/policyengine/database/simulation_table.py +++ /dev/null @@ -1,231 +0,0 @@ -from datetime import datetime -from typing import TYPE_CHECKING -from uuid import uuid4 - -from sqlmodel import Field, SQLModel - -from policyengine.models import Dataset, Dynamic, Model, ModelVersion, Policy, Simulation -from policyengine.utils.compress import compress_data, decompress_data - -from .link import TableLink - -if TYPE_CHECKING: - from .database import Database - - -class SimulationTable(SQLModel, table=True): - __tablename__ = "simulations" - - id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True) - created_at: datetime = Field(default_factory=datetime.now) - updated_at: datetime = Field(default_factory=datetime.now) - - policy_id: str | None = Field( - default=None, foreign_key="policies.id", ondelete="SET NULL" - ) - dynamic_id: str | None = Field( - default=None, foreign_key="dynamics.id", ondelete="SET NULL" - ) - dataset_id: str = Field(foreign_key="datasets.id", ondelete="CASCADE") - model_id: str = Field(foreign_key="models.id", ondelete="CASCADE") - model_version_id: str | None = Field( - default=None, foreign_key="model_versions.id", ondelete="SET NULL" - ) - - result: bytes | None = Field(default=None) - error: str | None = Field(default=None) - - @classmethod - def convert_from_model(cls, model: Simulation, database: "Database" = None) -> "SimulationTable": - """Convert a Simulation instance to a SimulationTable instance. 
- - Args: - model: The Simulation instance to convert - database: The database instance for persisting foreign objects if needed - - Returns: - A SimulationTable instance - """ - # Ensure all foreign objects are persisted if database is provided - if database: - if model.policy: - database.set(model.policy, commit=False) - if model.dynamic: - database.set(model.dynamic, commit=False) - if model.dataset: - database.set(model.dataset, commit=False) - if model.model: - database.set(model.model, commit=False) - if model.model_version: - database.set(model.model_version, commit=False) - - sim_table = cls( - id=model.id, - created_at=model.created_at, - updated_at=model.updated_at, - policy_id=model.policy.id if model.policy else None, - dynamic_id=model.dynamic.id if model.dynamic else None, - dataset_id=model.dataset.id if model.dataset else None, - model_id=model.model.id if model.model else None, - model_version_id=model.model_version.id if model.model_version else None, - result=compress_data(model.result) if model.result else None, - error=getattr(model, 'error', None), - ) - - # Handle nested aggregates if database is provided - if database and model.aggregates: - from .aggregate import AggregateTable - from sqlmodel import select - - # First ensure the simulation table is saved to the database - # This is necessary so the foreign key constraint is satisfied - # Check if it already exists - existing_sim = database.session.exec( - select(SimulationTable).where(SimulationTable.id == model.id) - ).first() - - if not existing_sim: - database.session.add(sim_table) - database.session.flush() - - # Track which aggregate IDs we want to keep - desired_agg_ids = {agg.id for agg in model.aggregates} - - # Delete only aggregates linked to this simulation that are NOT in the new list - existing_aggs = database.session.exec( - select(AggregateTable).where(AggregateTable.simulation_id == model.id) - ).all() - for agg in existing_aggs: - if agg.id not in desired_agg_ids: - database.session.delete(agg) - - # Now save/update the aggregates - for aggregate in model.aggregates: - # Check if this aggregate already exists in the database - existing_agg = database.session.exec( - select(AggregateTable).where(AggregateTable.id == aggregate.id) - ).first() - - if existing_agg: - # Update existing aggregate - agg_table = AggregateTable.convert_from_model(aggregate, database) - existing_agg.simulation_id = agg_table.simulation_id - existing_agg.entity = agg_table.entity - existing_agg.variable_name = agg_table.variable_name - existing_agg.year = agg_table.year - existing_agg.filter_variable_name = agg_table.filter_variable_name - existing_agg.filter_variable_value = agg_table.filter_variable_value - existing_agg.filter_variable_leq = agg_table.filter_variable_leq - existing_agg.filter_variable_geq = agg_table.filter_variable_geq - existing_agg.aggregate_function = agg_table.aggregate_function - existing_agg.value = agg_table.value - else: - # Create new aggregate - agg_table = AggregateTable.convert_from_model(aggregate, database) - database.session.add(agg_table) - database.session.flush() - - return sim_table - - def convert_to_model(self, database: "Database" = None) -> Simulation: - """Convert this SimulationTable instance to a Simulation instance. 
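-
-        Illustrative sketch (`sim_id` is a placeholder):
-
-            row = database.session.exec(
-                select(SimulationTable).where(SimulationTable.id == sim_id)
-            ).first()
-            simulation = row.convert_to_model(database)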
- - Args: - database: The database instance for resolving foreign keys - - Returns: - A Simulation instance - """ - from sqlmodel import select - - from .model_version_table import ModelVersionTable - from .policy_table import PolicyTable - from .dataset_table import DatasetTable - from .model_table import ModelTable - from .dynamic_table import DynamicTable - - # Resolve all foreign keys - policy = None - dynamic = None - dataset = None - model = None - model_version = None - - if database: - if self.policy_id: - policy_table = database.session.exec( - select(PolicyTable).where(PolicyTable.id == self.policy_id) - ).first() - if policy_table: - policy = policy_table.convert_to_model(database) - - if self.dynamic_id: - try: - dynamic_table = database.session.exec( - select(DynamicTable).where(DynamicTable.id == self.dynamic_id) - ).first() - if dynamic_table: - dynamic = dynamic_table.convert_to_model(database) - except: - # Dynamic table might not be defined yet - dynamic = database.get(Dynamic, id=self.dynamic_id) - - if self.dataset_id: - dataset_table = database.session.exec( - select(DatasetTable).where(DatasetTable.id == self.dataset_id) - ).first() - if dataset_table: - dataset = dataset_table.convert_to_model(database) - - if self.model_id: - model_table = database.session.exec( - select(ModelTable).where(ModelTable.id == self.model_id) - ).first() - if model_table: - model = model_table.convert_to_model(database) - - if self.model_version_id: - version_table = database.session.exec( - select(ModelVersionTable).where(ModelVersionTable.id == self.model_version_id) - ).first() - if version_table: - model_version = version_table.convert_to_model(database) - - # Load aggregates - aggregates = [] - if database: - from .aggregate import AggregateTable - from sqlmodel import select - - agg_tables = database.session.exec( - select(AggregateTable).where(AggregateTable.simulation_id == self.id) - ).all() - - for agg_table in agg_tables: - # Don't pass database to avoid circular reference issues - # The simulation reference will be set separately - agg_model = agg_table.convert_to_model(None) - aggregates.append(agg_model) - - sim = Simulation( - id=self.id, - created_at=self.created_at, - updated_at=self.updated_at, - policy=policy, - dynamic=dynamic, - dataset=dataset, - model=model, - model_version=model_version, - result=decompress_data(self.result) if self.result else None, - aggregates=aggregates, - ) - # Add error as dynamic attribute if present - if self.error: - sim.error = self.error - return sim - - -simulation_table_link = TableLink( - model_cls=Simulation, - table_cls=SimulationTable, -) diff --git a/src/policyengine/database/table_mixin.py b/src/policyengine/database/table_mixin.py deleted file mode 100644 index a29cdeb6..00000000 --- a/src/policyengine/database/table_mixin.py +++ /dev/null @@ -1,80 +0,0 @@ -from typing import TYPE_CHECKING, Any, ClassVar, TypeVar - -from pydantic import BaseModel -from sqlmodel import SQLModel - -if TYPE_CHECKING: - from .database import Database - -T = TypeVar("T", bound=BaseModel) - - -class TableConversionMixin: - """Mixin class for SQLModel tables to provide conversion methods between table instances and Pydantic models.""" - - _model_cls: ClassVar[type[BaseModel]] = None - _foreign_key_fields: ClassVar[dict[str, type[BaseModel]]] = {} - - @classmethod - def convert_from_model(cls, model: BaseModel, database: "Database" = None) -> SQLModel: - """Convert a Pydantic model instance to a table instance, resolving foreign objects to IDs. 
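-
-        Illustrative subclass configuration (`Thing` and `ThingTable` are
-        hypothetical names):
-
-            class ThingTable(TableConversionMixin, SQLModel, table=True):
-                _model_cls = Thing  # hypothetical Pydantic model
-                _foreign_key_fields = {"model_id": Model}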
- - Args: - model: The Pydantic model instance to convert - database: The database instance for resolving foreign objects (optional) - - Returns: - An instance of the SQLModel table class - """ - data = {} - - for field_name in cls.__annotations__.keys(): - # Check if this field is a foreign key that needs resolution - if field_name in cls._foreign_key_fields: - # Extract ID from the nested object - nested_obj = getattr(model, field_name.replace("_id", ""), None) - if nested_obj: - # If we need to ensure the foreign object exists in DB - if database: - database.set(nested_obj, commit=False) - data[field_name] = nested_obj.id if hasattr(nested_obj, "id") else None - else: - data[field_name] = None - elif hasattr(model, field_name): - # Direct field mapping - data[field_name] = getattr(model, field_name) - - return cls(**data) - - @classmethod - def convert_to_model(cls, table_instance: SQLModel, database: "Database" = None) -> BaseModel: - """Convert a table instance to a Pydantic model, resolving foreign key IDs to objects. - - Args: - table_instance: The SQLModel table instance to convert - database: The database instance for resolving foreign keys (required if foreign keys exist) - - Returns: - An instance of the Pydantic model class - """ - if cls._model_cls is None: - raise ValueError(f"Model class not set for {cls.__name__}") - - data = {} - - for field_name in cls._model_cls.__annotations__.keys(): - # Check if we need to resolve a foreign key - fk_field = f"{field_name}_id" - if fk_field in cls._foreign_key_fields and database: - # Resolve the foreign key to an object - fk_id = getattr(table_instance, fk_field, None) - if fk_id: - foreign_model_cls = cls._foreign_key_fields[fk_field] - data[field_name] = database.get(foreign_model_cls, id=fk_id) - else: - data[field_name] = None - elif hasattr(table_instance, field_name): - # Direct field mapping - data[field_name] = getattr(table_instance, field_name) - - return cls._model_cls(**data) \ No newline at end of file diff --git a/src/policyengine/database/versioned_dataset_table.py b/src/policyengine/database/versioned_dataset_table.py deleted file mode 100644 index 4e1524c9..00000000 --- a/src/policyengine/database/versioned_dataset_table.py +++ /dev/null @@ -1,45 +0,0 @@ -from uuid import uuid4 - -from sqlmodel import Field, SQLModel -from typing import TYPE_CHECKING - -from policyengine.models import VersionedDataset - -from .link import TableLink - -if TYPE_CHECKING: - from .database import Database - - -class VersionedDatasetTable(SQLModel, table=True): - __tablename__ = "versioned_datasets" - - id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True) - name: str = Field(nullable=False) - description: str = Field(nullable=False) - model_id: str | None = Field( - default=None, foreign_key="models.id", ondelete="SET NULL" - ) - - @classmethod - def convert_from_model(cls, model: VersionedDataset, database: "Database" = None) -> "VersionedDatasetTable": - """Convert a VersionedDataset instance to a VersionedDatasetTable instance.""" - return cls( - id=model.id, - name=model.name, - description=model.description, - ) - - def convert_to_model(self, database: "Database" = None) -> VersionedDataset: - """Convert this VersionedDatasetTable instance to a VersionedDataset instance.""" - return VersionedDataset( - id=self.id, - name=self.name, - description=self.description, - ) - - -versioned_dataset_table_link = TableLink( - model_cls=VersionedDataset, - table_cls=VersionedDatasetTable, -) diff --git 
a/src/policyengine/models/__init__.py b/src/policyengine/models/__init__.py
index 24fb823c..52f87c40 100644
--- a/src/policyengine/models/__init__.py
+++ b/src/policyengine/models/__init__.py
@@ -4,11 +4,11 @@
 from .baseline_parameter_value import (
     BaselineParameterValue as BaselineParameterValue,
 )
-from .baseline_variable import BaselineVariable as BaselineVariable
+from .variable import Variable as BaselineVariable
 from .dataset import Dataset as Dataset
 from .dynamic import Dynamic as Dynamic
-from .model import Model as Model
-from .model_version import ModelVersion as ModelVersion
+from .tax_benefit_model import TaxBenefitModel as Model
+from .tax_benefit_model_version import TaxBenefitModelVersion as ModelVersion
 from .parameter import Parameter as Parameter
 from .parameter_value import ParameterValue as ParameterValue
 from .policy import Policy as Policy
@@ -25,7 +25,7 @@
     policyengine_us_model as policyengine_us_model,
 )
 from .simulation import Simulation as Simulation
-from .versioned_dataset import VersionedDataset as VersionedDataset
+from .dataset_version import DatasetVersion as VersionedDataset
 
 # Rebuild models to handle circular references
 from .aggregate import Aggregate
diff --git a/src/policyengine/models/aggregate.py b/src/policyengine/models/aggregate.py
deleted file mode 100644
index 003b9aa4..00000000
--- a/src/policyengine/models/aggregate.py
+++ /dev/null
@@ -1,297 +0,0 @@
-from enum import Enum
-from typing import TYPE_CHECKING, Any, Literal
-from uuid import uuid4
-
-import pandas as pd
-from microdf import MicroDataFrame, MicroSeries
-from pydantic import BaseModel, ConfigDict, Field, SkipValidation
-
-if TYPE_CHECKING:
-    from policyengine.models import Simulation
-
-
-class AggregateType(str, Enum):
-    SUM = "sum"
-    MEAN = "mean"
-    MEDIAN = "median"
-    COUNT = "count"
-
-
-class DataEngine:
-    """Clean data processing engine for aggregations."""
-
-    def __init__(self, tables: dict):
-        """Initialize with simulation result tables."""
-        self.tables = self._prepare_tables(tables)
-
-    @staticmethod
-    def _prepare_tables(tables: dict) -> dict[str, pd.DataFrame]:
-        """Convert tables to DataFrames with MicroDataFrame for weighted columns."""
-        prepared = {}
-        for name, table in tables.items():
-            df = pd.DataFrame(table.copy() if hasattr(table, 'copy') else table)
-            weight_col = f"{name}_weight"
-            if weight_col in df.columns:
-                df = MicroDataFrame(df, weights=weight_col)
-            prepared[name] = df
-        return prepared
-
-    def infer_entity(self, variable: str) -> str:
-        """Infer which entity contains a variable."""
-        for entity, table in self.tables.items():
-            if variable in table.columns:
-                return entity
-        raise ValueError(f"Variable {variable} not found in any entity")
-
-    def get_variable_series(
-        self,
-        variable: str,
-        target_entity: str,
-        filters: dict[str, Any] | None = None
-    ) -> pd.Series:
-        """
-        Get variable series at target entity level, with optional filtering.
-
-        Handles cross-entity mapping automatically.
-        
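-        Illustrative sketch (assumes person/household tables like those
-        used in the tests below):
-
-            engine = DataEngine(simulation.result)
-            series = engine.get_variable_series(
-                "household_net_income",
-                target_entity="person",
-                filters={"variable": "age", "geq": 65},
-            )
-        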
- """ - # Find source entity - source_entity = self.infer_entity(variable) - - # Apply filters first (on target entity) - if filters: - mask = self._build_filter_mask(filters, target_entity) - target_table = self.tables[target_entity][mask] - else: - target_table = self.tables[target_entity] - - # Get variable (map if needed) - if source_entity == target_entity: - return target_table[variable] - else: - # Map across entities - source_series = self.tables[source_entity][variable] - mapped_series = self._map_variable(source_series, source_entity, target_entity) - # Apply filter mask to mapped series - if filters: - return mapped_series[mask] - return mapped_series - - def _build_filter_mask(self, filters: dict[str, Any], target_entity: str) -> pd.Series: - """Build boolean mask from filter specification.""" - target_table = self.tables[target_entity] - mask = pd.Series([True] * len(target_table), index=target_table.index) - - filter_variable = filters.get('variable') - if not filter_variable: - return mask - - # Get filter series (map if cross-entity) - filter_entity = self.infer_entity(filter_variable) - if filter_entity == target_entity: - filter_series = target_table[filter_variable] - else: - filter_series = self._map_variable( - self.tables[filter_entity][filter_variable], - filter_entity, - target_entity - ) - - # Apply value filters - if 'value' in filters and filters['value'] is not None: - mask &= filter_series == filters['value'] - - if 'leq' in filters and filters['leq'] is not None: - mask &= filter_series <= filters['leq'] - - if 'geq' in filters and filters['geq'] is not None: - mask &= filter_series >= filters['geq'] - - # Apply quantile filters - if 'quantile_leq' in filters and filters['quantile_leq'] is not None: - threshold = filter_series.quantile(filters['quantile_leq']) - mask &= filter_series <= threshold - - if 'quantile_geq' in filters and filters['quantile_geq'] is not None: - threshold = filter_series.quantile(filters['quantile_geq']) - mask &= filter_series >= threshold - - return mask - - def _map_variable( - self, - series: pd.Series, - source_entity: str, - target_entity: str - ) -> pd.Series: - """Map a variable from source to target entity.""" - if source_entity == target_entity: - return series - - # Default entity links (can be overridden) - person_links = { - "benunit": "person_benunit_id", - "household": "person_household_id", - "family": "person_family_id", - "tax_unit": "person_tax_unit_id", - "spm_unit": "person_spm_unit_id", - } - - # Group to person: copy values down - if source_entity != "person" and target_entity == "person": - link_col = person_links.get(source_entity) - if not link_col: - raise ValueError(f"No link from person to {source_entity}") - - person_table = self.tables["person"] - if link_col not in person_table.columns: - raise ValueError(f"Link column {link_col} not in person table") - - group_values = series.values - person_group_ids = person_table[link_col].values - mapped_values = [ - group_values[int(gid)] if int(gid) < len(group_values) else 0 - for gid in person_group_ids - ] - - # Return MicroSeries with person weights - weight_col = f"{target_entity}_weight" - if isinstance(person_table, MicroDataFrame) and weight_col in person_table.columns: - return MicroSeries(mapped_values, weights=person_table[weight_col]) - return pd.Series(mapped_values, index=person_table.index) - - # Person to group: aggregate up - elif source_entity == "person" and target_entity != "person": - link_col = person_links.get(target_entity) - if not 
link_col: - raise ValueError(f"No link from person to {target_entity}") - - person_table = self.tables["person"] - if link_col not in person_table.columns: - raise ValueError(f"Link column {link_col} not in person table") - - grouped = pd.DataFrame({ - link_col: person_table[link_col], - 'value': series - }).groupby(link_col)['value'].sum() - - target_table = self.tables[target_entity] - mapped_values = [grouped.get(i, 0) for i in range(len(target_table))] - - # Return MicroSeries with target entity weights - weight_col = f"{target_entity}_weight" - if isinstance(target_table, MicroDataFrame) and weight_col in target_table.columns: - return MicroSeries(mapped_values, weights=target_table[weight_col]) - return pd.Series(mapped_values, index=target_table.index) - - # Group to group: via person - else: - person_series = self._map_variable(series, source_entity, "person") - return self._map_variable(person_series, "person", target_entity) - - @staticmethod - def aggregate(series: pd.Series, function: AggregateType) -> float: - """Apply aggregation function to series.""" - if len(series) == 0: - return 0.0 - - # Avoid double-weighting weight columns - is_weight = ( - hasattr(series, 'name') and - series.name and - 'weight' in str(series.name).lower() - ) - - if function == AggregateType.SUM: - if is_weight: - return float(pd.Series(series.values).sum()) - return float(series.sum()) - elif function == AggregateType.MEAN: - return float(series.mean()) - elif function == AggregateType.MEDIAN: - return float(series.median()) - elif function == AggregateType.COUNT: - # For MicroSeries, sum the weights to get weighted population count - if isinstance(series, MicroSeries): - return float(series.weights.sum()) - return float(len(series)) - else: - raise ValueError(f"Unknown aggregate function: {function}") - - -class Aggregate(BaseModel): - """Aggregate calculation.""" - model_config = ConfigDict(arbitrary_types_allowed=True) - - id: str = Field(default_factory=lambda: str(uuid4())) - simulation: SkipValidation["Simulation | None"] = None - entity: str | None = None - variable_name: str - year: int | None = None - filter_variable_name: str | None = None - filter_variable_value: Any | None = None - filter_variable_leq: float | None = None - filter_variable_geq: float | None = None - filter_variable_quantile_leq: float | None = None - filter_variable_quantile_geq: float | None = None - aggregate_function: Literal[ - AggregateType.SUM, AggregateType.MEAN, AggregateType.MEDIAN, AggregateType.COUNT - ] - reportelement_id: str | None = None - value: float | None = None - - @staticmethod - def run(aggregates: list["Aggregate"]) -> list["Aggregate"]: - """Process aggregates efficiently by batching those with same simulation.""" - # Group by simulation for batch processing - by_simulation = {} - for agg in aggregates: - sim_id = id(agg.simulation) if agg.simulation else None - if sim_id not in by_simulation: - by_simulation[sim_id] = [] - by_simulation[sim_id].append(agg) - - results = [] - for sim_aggregates in by_simulation.values(): - if not sim_aggregates: - continue - - simulation = sim_aggregates[0].simulation - if simulation is None: - raise ValueError("Aggregate has no simulation") - - # Create data engine once per simulation (batch optimization) - engine = DataEngine(simulation.result) - - # Process each aggregate - for agg in sim_aggregates: - if agg.year is None: - agg.year = simulation.dataset.year - - # Infer entity if not specified - if agg.entity is None: - agg.entity = 
engine.infer_entity(agg.variable_name) - - # Build filter specification - filters = None - if agg.filter_variable_name: - filters = { - 'variable': agg.filter_variable_name, - 'value': agg.filter_variable_value, - 'leq': agg.filter_variable_leq, - 'geq': agg.filter_variable_geq, - 'quantile_leq': agg.filter_variable_quantile_leq, - 'quantile_geq': agg.filter_variable_quantile_geq, - } - - # Get variable series with filters - series = engine.get_variable_series( - agg.variable_name, - agg.entity, - filters - ) - - # Compute aggregate - agg.value = engine.aggregate(series, agg.aggregate_function) - results.append(agg) - - return results diff --git a/src/policyengine/models/aggregate_change.py b/src/policyengine/models/aggregate_change.py deleted file mode 100644 index e869d8fc..00000000 --- a/src/policyengine/models/aggregate_change.py +++ /dev/null @@ -1,127 +0,0 @@ -from typing import TYPE_CHECKING, Any, Literal -from uuid import uuid4 - -from pydantic import BaseModel, ConfigDict, Field, SkipValidation - -from .aggregate import AggregateType, DataEngine - -if TYPE_CHECKING: - from policyengine.models import Simulation - - -class AggregateChange(BaseModel): - """Calculates the change in an aggregate between baseline and comparison simulations.""" - model_config = ConfigDict(arbitrary_types_allowed=True) - - id: str = Field(default_factory=lambda: str(uuid4())) - baseline_simulation: SkipValidation["Simulation | None"] = None - comparison_simulation: SkipValidation["Simulation | None"] = None - entity: str | None = None - variable_name: str - year: int | None = None - filter_variable_name: str | None = None - filter_variable_value: Any | None = None - filter_variable_leq: float | None = None - filter_variable_geq: float | None = None - filter_variable_quantile_leq: float | None = None - filter_variable_quantile_geq: float | None = None - aggregate_function: Literal[ - AggregateType.SUM, AggregateType.MEAN, AggregateType.MEDIAN, AggregateType.COUNT - ] - reportelement_id: str | None = None - - baseline_value: float | None = None - comparison_value: float | None = None - change: float | None = None - relative_change: float | None = None - - @staticmethod - def run(aggregate_changes: list["AggregateChange"]) -> list["AggregateChange"]: - """Process aggregate changes efficiently by batching those with same simulation pair.""" - # Group by simulation pair for batch processing - by_sim_pair = {} - for agg_change in aggregate_changes: - if agg_change.baseline_simulation is None: - raise ValueError("AggregateChange missing baseline_simulation") - if agg_change.comparison_simulation is None: - raise ValueError("AggregateChange missing comparison_simulation") - - key = ( - id(agg_change.baseline_simulation), - id(agg_change.comparison_simulation) - ) - if key not in by_sim_pair: - by_sim_pair[key] = [] - by_sim_pair[key].append(agg_change) - - results = [] - for pair_aggregates in by_sim_pair.values(): - if not pair_aggregates: - continue - - # Get simulation objects - baseline_sim = pair_aggregates[0].baseline_simulation - comparison_sim = pair_aggregates[0].comparison_simulation - - # Create data engines once per simulation pair (batch optimization) - baseline_engine = DataEngine(baseline_sim.result) - comparison_engine = DataEngine(comparison_sim.result) - - # Process each aggregate change - for agg_change in pair_aggregates: - if agg_change.year is None: - agg_change.year = baseline_sim.dataset.year - - # Infer entity if not specified - if agg_change.entity is None: - agg_change.entity = 
baseline_engine.infer_entity(agg_change.variable_name) - - # Build filter specification - filters = None - if agg_change.filter_variable_name: - filters = { - 'variable': agg_change.filter_variable_name, - 'value': agg_change.filter_variable_value, - 'leq': agg_change.filter_variable_leq, - 'geq': agg_change.filter_variable_geq, - 'quantile_leq': agg_change.filter_variable_quantile_leq, - 'quantile_geq': agg_change.filter_variable_quantile_geq, - } - - # Get variable series with filters for both simulations - baseline_series = baseline_engine.get_variable_series( - agg_change.variable_name, - agg_change.entity, - filters - ) - comparison_series = comparison_engine.get_variable_series( - agg_change.variable_name, - agg_change.entity, - filters - ) - - # Compute aggregates - agg_change.baseline_value = baseline_engine.aggregate( - baseline_series, - agg_change.aggregate_function - ) - agg_change.comparison_value = comparison_engine.aggregate( - comparison_series, - agg_change.aggregate_function - ) - - # Calculate changes - agg_change.change = agg_change.comparison_value - agg_change.baseline_value - - if agg_change.baseline_value != 0: - agg_change.relative_change = ( - agg_change.change / abs(agg_change.baseline_value) - ) - else: - agg_change.relative_change = ( - None if agg_change.comparison_value == 0 else float('inf') - ) - - results.append(agg_change) - - return results diff --git a/src/policyengine/models/baseline_parameter_value.py b/src/policyengine/models/baseline_parameter_value.py deleted file mode 100644 index 8afb6e22..00000000 --- a/src/policyengine/models/baseline_parameter_value.py +++ /dev/null @@ -1,16 +0,0 @@ -from datetime import datetime -from uuid import uuid4 - -from pydantic import BaseModel, Field - -from .model_version import ModelVersion -from .parameter import Parameter - - -class BaselineParameterValue(BaseModel): - id: str = Field(default_factory=lambda: str(uuid4())) - parameter: Parameter - model_version: ModelVersion - value: float | int | str | bool | list | None = None - start_date: datetime - end_date: datetime | None = None diff --git a/src/policyengine/models/baseline_variable.py b/src/policyengine/models/baseline_variable.py deleted file mode 100644 index b0e739b1..00000000 --- a/src/policyengine/models/baseline_variable.py +++ /dev/null @@ -1,12 +0,0 @@ -from pydantic import BaseModel - -from .model_version import ModelVersion - - -class BaselineVariable(BaseModel): - id: str - model_version: ModelVersion - entity: str - label: str | None = None - description: str | None = None - data_type: type | None = None diff --git a/src/policyengine/models/dataset.py b/src/policyengine/models/dataset.py index 59dd626f..4fc899b5 100644 --- a/src/policyengine/models/dataset.py +++ b/src/policyengine/models/dataset.py @@ -3,16 +3,14 @@ from pydantic import BaseModel, Field -from .model import Model -from .versioned_dataset import VersionedDataset +from .tax_benefit_model import TaxBenefitModel +from .dataset_version import DatasetVersion class Dataset(BaseModel): id: str = Field(default_factory=lambda: str(uuid4())) name: str - description: str | None = None - version: str | None = None - versioned_dataset: VersionedDataset | None = None - year: int | None = None - data: Any | None = None - model: Model | None = None + description: str + dataset_version: DatasetVersion | None = None + filepath: str + tax_benefit_model: TaxBenefitModel = None diff --git a/src/policyengine/models/dataset_version.py b/src/policyengine/models/dataset_version.py new file mode 100644 
index 00000000..7594c812
--- /dev/null
+++ b/src/policyengine/models/dataset_version.py
@@ -0,0 +1,15 @@
+from uuid import uuid4
+
+from pydantic import BaseModel, Field
+from typing import TYPE_CHECKING
+
+from .tax_benefit_model import TaxBenefitModel
+if TYPE_CHECKING:
+    from .dataset import Dataset
+
+
+class DatasetVersion(BaseModel):
+    id: str = Field(default_factory=lambda: str(uuid4()))
+    dataset: "Dataset"
+    description: str
+    tax_benefit_model: TaxBenefitModel | None = None
diff --git a/src/policyengine/models/dynamic.py b/src/policyengine/models/dynamic.py
index 40cf364f..1a88a5f6 100644
--- a/src/policyengine/models/dynamic.py
+++ b/src/policyengine/models/dynamic.py
@@ -3,13 +3,14 @@
 from uuid import uuid4
 
 from pydantic import BaseModel, Field
+from .parameter_value import ParameterValue
 
 
 class Dynamic(BaseModel):
     id: str = Field(default_factory=lambda: str(uuid4()))
     name: str
     description: str | None = None
-    parameter_values: list[str] = []
+    parameter_values: list[ParameterValue] = []
     simulation_modifier: Callable | None = None
     created_at: datetime = Field(default_factory=datetime.now)
     updated_at: datetime = Field(default_factory=datetime.now)
diff --git a/src/policyengine/models/model.py b/src/policyengine/models/model.py
deleted file mode 100644
index 89cac9b8..00000000
--- a/src/policyengine/models/model.py
+++ /dev/null
@@ -1,126 +0,0 @@
-from collections.abc import Callable
-from datetime import datetime
-from typing import TYPE_CHECKING
-
-from pydantic import BaseModel
-
-if TYPE_CHECKING:
-    from .baseline_parameter_value import BaselineParameterValue
-    from .baseline_variable import BaselineVariable
-    from .parameter import Parameter
-
-
-class Model(BaseModel):
-    id: str
-    name: str
-    description: str | None = None
-    simulation_function: Callable
-
-    def create_seed_objects(self, model_version):
-        from policyengine_core.parameters import Parameter as CoreParameter
-
-        from .baseline_parameter_value import BaselineParameterValue
-        from .baseline_variable import BaselineVariable
-        from .parameter import Parameter
-
-        if self.id == "policyengine_uk":
-            from policyengine_uk.tax_benefit_system import system
-        elif self.id == "policyengine_us":
-            from policyengine_us.system import system
-        else:
-            raise ValueError("Unsupported model.")
-
-        parameters = []
-        baseline_parameter_values = []
-        baseline_variables = []
-        seen_parameter_ids = set()
-
-        for parameter in system.parameters.get_descendants():
-            # Skip if we've already processed this parameter ID
-            if parameter.name in seen_parameter_ids:
-                continue
-            seen_parameter_ids.add(parameter.name)
-            param = Parameter(
-                id=parameter.name,
-                description=parameter.description,
-                data_type=None,
-                model=self,
-                label=parameter.metadata.get("label"),
-                unit=parameter.metadata.get("unit"),
-            )
-            parameters.append(param)
-            if isinstance(parameter, CoreParameter):
-                values = parameter.values_list[::-1]
-                param.data_type = type(values[-1].value)
-                for i in range(len(values)):
-                    value_at_instant = values[i]
-                    instant_str = safe_parse_instant_str(
-                        value_at_instant.instant_str
-                    )
-                    if i + 1 < len(values):
-                        next_instant_str = safe_parse_instant_str(
-                            values[i + 1].instant_str
-                        )
-                    else:
-                        next_instant_str = None
-                    baseline_param_value = BaselineParameterValue(
-                        parameter=param,
-                        model_version=model_version,
-                        value=value_at_instant.value,
-                        start_date=instant_str,
-                        end_date=next_instant_str,
-                    )
-                    baseline_parameter_values.append(baseline_param_value)
-
-        for variable in system.variables.values():
-            baseline_variable = BaselineVariable(
- id=variable.name, - model_version=model_version, - entity=variable.entity.key, - label=variable.label, - description=variable.documentation, - data_type=variable.value_type, - ) - baseline_variables.append(baseline_variable) - - return SeedObjects( - parameters=parameters, - baseline_parameter_values=baseline_parameter_values, - baseline_variables=baseline_variables, - ) - - -def safe_parse_instant_str(instant_str: str) -> datetime: - if instant_str == "0000-01-01": - return datetime(1, 1, 1) - else: - try: - return datetime.strptime(instant_str, "%Y-%m-%d") - except ValueError: - # Handle invalid dates like 2021-06-31 - # Try to parse year and month, then use last valid day - parts = instant_str.split("-") - if len(parts) == 3: - year = int(parts[0]) - month = int(parts[1]) - day = int(parts[2]) - - # Find the last valid day of the month - import calendar - - last_day = calendar.monthrange(year, month)[1] - if day > last_day: - print( - f"Warning: Invalid date {instant_str}, using {year}-{month:02d}-{last_day:02d}" - ) - return datetime(year, month, last_day) - - # If we can't parse it at all, print and raise - print(f"Error: Cannot parse date {instant_str}") - raise - - -class SeedObjects(BaseModel): - parameters: list["Parameter"] - baseline_parameter_values: list["BaselineParameterValue"] - baseline_variables: list["BaselineVariable"] diff --git a/src/policyengine/models/model_version.py b/src/policyengine/models/model_version.py deleted file mode 100644 index 18b542f8..00000000 --- a/src/policyengine/models/model_version.py +++ /dev/null @@ -1,14 +0,0 @@ -from datetime import datetime -from uuid import uuid4 - -from pydantic import BaseModel, Field - -from .model import Model - - -class ModelVersion(BaseModel): - id: str = Field(default_factory=lambda: str(uuid4())) - model: Model - version: str - description: str | None = None - created_at: datetime = Field(default_factory=datetime.now) diff --git a/src/policyengine/models/parameter.py b/src/policyengine/models/parameter.py index ec7ef7be..54e3e116 100644 --- a/src/policyengine/models/parameter.py +++ b/src/policyengine/models/parameter.py @@ -2,13 +2,13 @@ from pydantic import BaseModel, Field -from .model import Model +from .tax_benefit_model_version import TaxBenefitModelVersion class Parameter(BaseModel): id: str = Field(default_factory=lambda: str(uuid4())) + name: str description: str | None = None data_type: type | None = None - model: Model | None = None - label: str | None = None + tax_benefit_model_version: TaxBenefitModelVersion unit: str | None = None diff --git a/src/policyengine/models/policyengine_uk.py b/src/policyengine/models/policyengine_uk.py deleted file mode 100644 index 22c72546..00000000 --- a/src/policyengine/models/policyengine_uk.py +++ /dev/null @@ -1,117 +0,0 @@ -import importlib.metadata - -import pandas as pd - -from ..models import Dataset, Dynamic, Model, ModelVersion, Policy - - -def run_policyengine_uk( - dataset: "Dataset", - policy: "Policy | None" = None, - dynamic: "Dynamic | None" = None, -) -> dict[str, "pd.DataFrame"]: - data: dict[str, pd.DataFrame] = dataset.data - - from policyengine_uk import Microsimulation - from policyengine_uk.data import UKSingleYearDataset - - pe_input_data = UKSingleYearDataset( - person=data["person"], - benunit=data["benunit"], - household=data["household"], - fiscal_year=dataset.year, - ) - - sim = Microsimulation(dataset=pe_input_data) - sim.default_calculation_period = dataset.year - - def simulation_modifier(sim: Microsimulation): - if policy is not None 
and len(policy.parameter_values) > 0: - for parameter_value in policy.parameter_values: - if parameter_value.parameter is None: - raise ValueError(f"Parameter value {parameter_value.id} has no parameter set - the policy contains invalid data") - sim.tax_benefit_system.parameters.get_child( - parameter_value.parameter.id - ).update( - value=parameter_value.value, - start=parameter_value.start_date.strftime("%Y-%m-%d"), - stop=parameter_value.end_date.strftime("%Y-%m-%d") - if parameter_value.end_date - else None, - ) - - if dynamic is not None and len(dynamic.parameter_values) > 0: - for parameter_value in dynamic.parameter_values: - if parameter_value.parameter is None: - raise ValueError(f"Parameter value {parameter_value.id} has no parameter set - the dynamic contains invalid data") - sim.tax_benefit_system.parameters.get_child( - parameter_value.parameter.id - ).update( - value=parameter_value.value, - start=parameter_value.start_date.strftime("%Y-%m-%d"), - stop=parameter_value.end_date.strftime("%Y-%m-%d") - if parameter_value.end_date - else None, - ) - - if dynamic is not None and dynamic.simulation_modifier is not None: - dynamic.simulation_modifier(sim) - if policy is not None and policy.simulation_modifier is not None: - policy.simulation_modifier(sim) - - simulation_modifier(sim) - - output_data = {} - - variable_blacklist = [ # TEMPORARY: we need to fix policyengine-uk to make these only take a long time with non-default parameters set to true. - "is_uc_entitled_baseline", - "income_elasticity_lsr", - "child_benefit_opts_out", - "housing_benefit_baseline_entitlement", - "baseline_ctc_entitlement", - "pre_budget_change_household_tax", - "pre_budget_change_household_net_income", - "is_on_cliff", - "marginal_tax_rate_on_capital_gains", - "relative_capital_gains_mtr_change", - "pre_budget_change_ons_equivalised_income_decile", - "substitution_elasticity", - "marginal_tax_rate", - "cliff_evaluated", - "cliff_gap", - "substitution_elasticity_lsr", - "relative_wage_change", - "relative_income_change", - "pre_budget_change_household_benefits", - ] - - for entity in ["person", "benunit", "household"]: - output_data[entity] = pd.DataFrame() - for variable in sim.tax_benefit_system.variables.values(): - correct_entity = variable.entity.key == entity - if variable.name in variable_blacklist: - continue - if variable.definition_period != "year": - continue - if correct_entity: - output_data[entity][variable.name] = sim.calculate( - variable.name - ).values - output_data[entity] = pd.DataFrame(output_data[entity]) - - return output_data - - -policyengine_uk_model = Model( - id="policyengine_uk", - name="PolicyEngine UK", - description="PolicyEngine's open-source tax-benefit microsimulation model.", - simulation_function=run_policyengine_uk, -) - -# Get policyengine-uk version - -policyengine_uk_latest_version = ModelVersion( - model=policyengine_uk_model, - version=importlib.metadata.distribution("policyengine_uk").version, -) diff --git a/src/policyengine/models/policyengine_us.py b/src/policyengine/models/policyengine_us.py deleted file mode 100644 index 807859f0..00000000 --- a/src/policyengine/models/policyengine_us.py +++ /dev/null @@ -1,119 +0,0 @@ -import importlib.metadata - -import pandas as pd - -from ..models import Dataset, Dynamic, Model, ModelVersion, Policy - - -def run_policyengine_us( - dataset: "Dataset", - policy: "Policy | None" = None, - dynamic: "Dynamic | None" = None, -) -> dict[str, "pd.DataFrame"]: - data: dict[str, pd.DataFrame] = dataset.data - - person_df = 
pd.DataFrame() - - for table_name, table in data.items(): - if table_name == "person": - for col in table.columns: - person_df[f"{col}__{dataset.year}"] = table[col].values - else: - foreign_key = data["person"][f"person_{table_name}_id"] - primary_key = data[table_name][f"{table_name}_id"] - - projected = table.set_index(primary_key).loc[foreign_key] - - for col in projected.columns: - person_df[f"{col}__{dataset.year}"] = projected[col].values - - from policyengine_us import Microsimulation - - sim = Microsimulation(dataset=person_df) - sim.default_calculation_period = dataset.year - - def simulation_modifier(sim: Microsimulation): - if policy is not None and len(policy.parameter_values) > 0: - for parameter_value in policy.parameter_values: - if parameter_value.parameter is None: - raise ValueError(f"Parameter value {parameter_value.id} has no parameter set - the policy contains invalid data") - sim.tax_benefit_system.parameters.get_child( - parameter_value.parameter.id - ).update( - parameter_value.value, - start=parameter_value.start_date.strftime("%Y-%m-%d"), - stop=parameter_value.end_date.strftime("%Y-%m-%d") - if parameter_value.end_date - else None, - ) - - if dynamic is not None and len(dynamic.parameter_values) > 0: - for parameter_value in dynamic.parameter_values: - if parameter_value.parameter is None: - raise ValueError(f"Parameter value {parameter_value.id} has no parameter set - the dynamic contains invalid data") - sim.tax_benefit_system.parameters.get_child( - parameter_value.parameter.id - ).update( - parameter_value.value, - start=parameter_value.start_date.strftime("%Y-%m-%d"), - stop=parameter_value.end_date.strftime("%Y-%m-%d") - if parameter_value.end_date - else None, - ) - - if dynamic is not None and dynamic.simulation_modifier is not None: - dynamic.simulation_modifier(sim) - if policy is not None and policy.simulation_modifier is not None: - policy.simulation_modifier(sim) - - simulation_modifier(sim) - - # Skip reforms for now - - output_data = {} - - variable_whitelist = [ - "household_net_income", - ] - - for variable in variable_whitelist: - sim.calculate(variable) - - for entity in [ - "person", - "marital_unit", - "family", - "tax_unit", - "spm_unit", - "household", - ]: - output_data[entity] = pd.DataFrame() - for variable in sim.tax_benefit_system.variables.values(): - correct_entity = variable.entity.key == entity - if str(dataset.year) not in list( - map(str, sim.get_holder(variable.name).get_known_periods()) - ): - continue - if variable.definition_period != "year": - continue - if not correct_entity: - continue - output_data[entity][variable.name] = sim.calculate(variable.name).values - - return output_data - - -policyengine_us_model = Model( - id="policyengine_us", - name="PolicyEngine US", - description="PolicyEngine's open-source tax-benefit microsimulation model.", - simulation_function=run_policyengine_us, -) - -# Get policyengine-uk version - - -policyengine_us_latest_version = ModelVersion( - model=policyengine_us_model, - version=importlib.metadata.distribution("policyengine_us").version, -) diff --git a/src/policyengine/models/simulation.py b/src/policyengine/models/simulation.py index 7c0abad1..490a0ad2 100644 --- a/src/policyengine/models/simulation.py +++ b/src/policyengine/models/simulation.py @@ -6,8 +6,8 @@ from .dataset import Dataset from .dynamic import Dynamic -from .model import Model -from .model_version import ModelVersion +from .tax_benefit_model import TaxBenefitModel +from .tax_benefit_model_version import 
TaxBenefitModelVersion from .policy import Policy @@ -20,21 +20,6 @@ class Simulation(BaseModel): dynamic: Dynamic | None = None dataset: Dataset | None = None - model: Model | None = None - model_version: ModelVersion | None = None - result: Any | None = None - aggregates: list = Field(default_factory=list) # Will be list[Aggregate] but avoid circular import - - def run(self): - if not self.model: - raise ValueError("Cannot run simulation: model is not set") - if not self.dataset: - raise ValueError("Cannot run simulation: dataset is not set") - - self.result = self.model.simulation_function( - dataset=self.dataset, - policy=self.policy, - dynamic=self.dynamic, - ) - self.updated_at = datetime.now() - return self.result + tax_benefit_model: TaxBenefitModel | None = None + tax_benefit_model_version: TaxBenefitModelVersion | None = None + output_file_path: str | None = None diff --git a/src/policyengine/models/tax_benefit_model.py b/src/policyengine/models/tax_benefit_model.py new file mode 100644 index 00000000..5ad2ac84 --- /dev/null +++ b/src/policyengine/models/tax_benefit_model.py @@ -0,0 +1,15 @@ +from collections.abc import Callable +from datetime import datetime +from typing import TYPE_CHECKING + +from pydantic import BaseModel + +if TYPE_CHECKING: + from .variable import Variable + from .parameter import Parameter + + +class TaxBenefitModel(BaseModel): + id: str + name: str + description: str | None = None diff --git a/build/lib/policyengine/models/model_version.py b/src/policyengine/models/tax_benefit_model_version.py similarity index 69% rename from build/lib/policyengine/models/model_version.py rename to src/policyengine/models/tax_benefit_model_version.py index 18b542f8..1c702b06 100644 --- a/build/lib/policyengine/models/model_version.py +++ b/src/policyengine/models/tax_benefit_model_version.py @@ -3,12 +3,12 @@ from pydantic import BaseModel, Field -from .model import Model +from .tax_benefit_model import TaxBenefitModel -class ModelVersion(BaseModel): +class TaxBenefitModelVersion(BaseModel): id: str = Field(default_factory=lambda: str(uuid4())) - model: Model + model: TaxBenefitModel version: str description: str | None = None created_at: datetime = Field(default_factory=datetime.now) diff --git a/src/policyengine/models/variable.py b/src/policyengine/models/variable.py new file mode 100644 index 00000000..a0b8b246 --- /dev/null +++ b/src/policyengine/models/variable.py @@ -0,0 +1,12 @@ +from pydantic import BaseModel + +from .tax_benefit_model_version import TaxBenefitModelVersion + + +class Variable(BaseModel): + id: str + tax_benefit_model_version: TaxBenefitModelVersion + entity: str + name: str | None = None + description: str | None = None + data_type: type = None diff --git a/src/policyengine/models/versioned_dataset.py b/src/policyengine/models/versioned_dataset.py deleted file mode 100644 index 2f5e14f7..00000000 --- a/src/policyengine/models/versioned_dataset.py +++ /dev/null @@ -1,12 +0,0 @@ -from uuid import uuid4 - -from pydantic import BaseModel, Field - -from .model import Model - - -class VersionedDataset(BaseModel): - id: str = Field(default_factory=lambda: str(uuid4())) - name: str - description: str - model: Model | None = None diff --git a/src/policyengine/utils/charts.py b/src/policyengine/utils/charts.py deleted file mode 100644 index 0cee7048..00000000 --- a/src/policyengine/utils/charts.py +++ /dev/null @@ -1,286 +0,0 @@ -"""Chart formatting utilities for PolicyEngine.""" - -import plotly.graph_objects as go -from IPython.display import HTML - 
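-# Illustrative usage sketch (data and labels are hypothetical):
-#
-#     fig = go.Figure(data=[go.Bar(x=["2024", "2025"], y=[1.0, 1.5])])
-#     fig = format_figure(fig, title="Example", x_title="Year", y_title="Value")
-#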
-COLOUR_SCHEMES = { - "teal": { - "primary": "#319795", - "secondary": "#38B2AC", - "tertiary": "#4FD1C5", - "light": "#81E6D9", - "lighter": "#B2F5EA", - "lightest": "#E6FFFA", - "dark": "#2C7A7B", - "darker": "#285E61", - "darkest": "#234E52", - }, - "blue": { - "primary": "#0EA5E9", - "secondary": "#0284C7", - "tertiary": "#38BDF8", - "light": "#7DD3FC", - "lighter": "#BAE6FD", - "lightest": "#E0F2FE", - "dark": "#026AA2", - "darker": "#075985", - "darkest": "#0C4A6E", - }, - "gray": { - "primary": "#6B7280", - "secondary": "#9CA3AF", - "tertiary": "#D1D5DB", - "light": "#E2E8F0", - "lighter": "#F2F4F7", - "lightest": "#F9FAFB", - "dark": "#4B5563", - "darker": "#344054", - "darkest": "#101828", - }, -} - -DEFAULT_COLOURS = [ - COLOUR_SCHEMES["teal"]["primary"], - COLOUR_SCHEMES["blue"]["primary"], - COLOUR_SCHEMES["teal"]["secondary"], - COLOUR_SCHEMES["blue"]["secondary"], - COLOUR_SCHEMES["teal"]["tertiary"], - COLOUR_SCHEMES["blue"]["tertiary"], - COLOUR_SCHEMES["gray"]["dark"], - COLOUR_SCHEMES["teal"]["dark"], -] - - -def add_fonts() -> HTML: - """Return HTML to add Google Fonts for Roboto and Roboto Mono.""" - return HTML(""" - - - - """) - - -def format_figure( - fig: go.Figure, - title: str | None = None, - x_title: str | None = None, - y_title: str | None = None, - colour_scheme: str = "teal", - show_grid: bool = True, - show_legend: bool = True, - height: int | None = None, - width: int | None = None, -) -> go.Figure: - """Apply consistent formatting to a Plotly figure. - - Args: - fig: The Plotly figure to format - title: Optional title for the chart - x_title: Optional x-axis title - y_title: Optional y-axis title - colour_scheme: Colour scheme name (teal, blue, gray) - show_grid: Whether to show gridlines - show_legend: Whether to show the legend - height: Optional figure height in pixels - width: Optional figure width in pixels - - Returns: - The formatted figure - """ - - colours = COLOUR_SCHEMES.get(colour_scheme, COLOUR_SCHEMES["teal"]) - - # Update traces with colour scheme - for i, trace in enumerate(fig.data): - if hasattr(trace, "marker"): - trace.marker.color = DEFAULT_COLOURS[i % len(DEFAULT_COLOURS)] - if hasattr(trace, "line"): - trace.line.color = DEFAULT_COLOURS[i % len(DEFAULT_COLOURS)] - trace.line.width = 2 - - # Base layout settings - layout_updates = { - "font": { - "family": "Roboto, sans-serif", - "size": 14, - "color": COLOUR_SCHEMES["gray"]["darkest"], - }, - "plot_bgcolor": "white", - "paper_bgcolor": "white", - "showlegend": show_legend, - "hovermode": "x unified", - "hoverlabel": { - "bgcolor": "white", - "font": {"family": "Roboto Mono, monospace", "size": 12}, - "bordercolor": colours["light"], - }, - } - - # Add title if provided - if title: - layout_updates["title"] = { - "text": title, - "font": { - "family": "Roboto, sans-serif", - "size": 20, - "color": COLOUR_SCHEMES["gray"]["darkest"], - "weight": 500, - }, - } - - # Configure axes - axis_config = { - "showgrid": show_grid, - "gridcolor": COLOUR_SCHEMES["gray"]["light"], - "gridwidth": 1, - "zeroline": True, - "zerolinecolor": COLOUR_SCHEMES["gray"]["lighter"], - "zerolinewidth": 1, - "tickfont": { - "family": "Roboto Mono, monospace", - "size": 11, - "color": COLOUR_SCHEMES["gray"]["primary"], - }, - "titlefont": { - "family": "Roboto, sans-serif", - "size": 14, - "color": COLOUR_SCHEMES["gray"]["dark"], - }, - "linecolor": COLOUR_SCHEMES["gray"]["light"], - "linewidth": 1, - "showline": True, - "mirror": False, - } - - layout_updates["xaxis"] = axis_config.copy() - 
layout_updates["yaxis"] = axis_config.copy()
-
-    if x_title:
-        layout_updates["xaxis"]["title"] = x_title
-    if y_title:
-        layout_updates["yaxis"]["title"] = y_title
-
-    layout_updates["showlegend"] = len(fig.data) > 1 and show_legend
-
-    # Set dimensions if provided
-    if height:
-        layout_updates["height"] = height
-    if width:
-        layout_updates["width"] = width
-
-    fig.update_layout(**layout_updates)
-
-    fig.update_xaxes(title_font_color=COLOUR_SCHEMES["gray"]["primary"])
-    fig.update_yaxes(title_font_color=COLOUR_SCHEMES["gray"]["primary"])
-
-    # Add text annotations to bars in bar charts
-    if any(isinstance(trace, go.Bar) for trace in fig.data):
-        for trace in fig.data:
-            if isinstance(trace, go.Bar):
-                trace.texttemplate = "%{y:,.0f}"
-                trace.textposition = "outside"
-                trace.textfont = {
-                    "family": "Roboto Mono, monospace",
-                    "size": 11,
-                    "color": COLOUR_SCHEMES["gray"]["primary"],
-                }
-
-    return fig
-
-
-def create_bar_chart(
-    data: dict[str, list],
-    x: str,
-    y: str,
-    title: str | None = None,
-    colour_scheme: str = "teal",
-    **kwargs,
-) -> go.Figure:
-    """Create a formatted bar chart.
-
-    Args:
-        data: Dictionary with data for the chart
-        x: Column name for x-axis
-        y: Column name for y-axis
-        title: Optional chart title
-        colour_scheme: Colour scheme to use
-        **kwargs: Additional arguments for format_figure
-
-    Returns:
-        Formatted bar chart figure
-    """
-    fig = go.Figure(
-        data=[
-            go.Bar(
-                x=data[x],
-                y=data[y],
-                marker_color=COLOUR_SCHEMES[colour_scheme]["primary"],
-                marker_line_color=COLOUR_SCHEMES[colour_scheme]["dark"],
-                marker_line_width=1,
-                hovertemplate=f"{x}: "
-                + "%{x}<br>"
-                + f"{y}: "
-                + "%{y:,.0f}",
-            )
-        ]
-    )
-
-    return format_figure(
-        fig,
-        title=title,
-        x_title=x,
-        y_title=y,
-        colour_scheme=colour_scheme,
-        **kwargs,
-    )
-
-
-def create_line_chart(
-    data: dict[str, list],
-    x: str,
-    y: str | list[str],
-    title: str | None = None,
-    colour_scheme: str = "teal",
-    **kwargs,
-) -> go.Figure:
-    """Create a formatted line chart.
-
-    Args:
-        data: Dictionary with data for the chart
-        x: Column name for x-axis
-        y: Column name(s) for y-axis (can be a list for multiple lines)
-        title: Optional chart title
-        colour_scheme: Colour scheme to use
-        **kwargs: Additional arguments for format_figure
-
-    Returns:
-        Formatted line chart figure
-    """
-    traces = []
-    y_columns = y if isinstance(y, list) else [y]
-
-    for i, y_col in enumerate(y_columns):
-        traces.append(
-            go.Scatter(
-                x=data[x],
-                y=data[y_col],
-                mode="lines+markers",
-                name=y_col,
-                line=dict(
-                    color=DEFAULT_COLOURS[i % len(DEFAULT_COLOURS)], width=2
-                ),
-                marker=dict(size=6),
-                hovertemplate=f"{y_col}: " + "%{y:,.0f}",
-            )
-        )
-
-    fig = go.Figure(data=traces)
-
-    return format_figure(
-        fig,
-        title=title,
-        x_title=x,
-        y_title=y_columns[0] if len(y_columns) == 1 else None,
-        colour_scheme=colour_scheme,
-        **kwargs,
-    )
diff --git a/src/policyengine/utils/compress.py b/src/policyengine/utils/compress.py
deleted file mode 100644
index 19180e2a..00000000
--- a/src/policyengine/utils/compress.py
+++ /dev/null
@@ -1,20 +0,0 @@
-import pickle
-from typing import Any
-
-import blosc
-
-
-def compress_data(data: Any) -> bytes:
-    """Compress data using blosc after pickling."""
-    pickled_data = pickle.dumps(data)
-    compressed_data = blosc.compress(
-        pickled_data, typesize=8, cname="zstd", clevel=9, shuffle=blosc.SHUFFLE
-    )
-    return compressed_data
-
-
-def decompress_data(compressed_data: bytes) -> Any:
-    """Decompress data using blosc and then unpickle."""
-    decompressed_data = blosc.decompress(compressed_data)
-    data = pickle.loads(decompressed_data)
-    return data
diff --git a/src/policyengine/utils/datasets.py b/src/policyengine/utils/datasets.py
deleted file mode 100644
index 02090e11..00000000
--- a/src/policyengine/utils/datasets.py
+++ /dev/null
@@ -1,71 +0,0 @@
-import pandas as pd
-
-from policyengine.models import Dataset
-
-
-def create_uk_dataset(
-    dataset: str = "enhanced_frs_2023_24.h5",
-    year: int = 2029,
-):
-    from policyengine_uk import Microsimulation
-
-    from policyengine.models.policyengine_uk import policyengine_uk_model
-
-    sim = Microsimulation(
-        dataset="hf://policyengine/policyengine-uk-data/" + dataset
-    )
-    sim.default_calculation_period = year
-
-    tables = {
-        "person": pd.DataFrame(sim.dataset[year].person),
-        "benunit": pd.DataFrame(sim.dataset[year].benunit),
-        "household": pd.DataFrame(sim.dataset[year].household),
-    }
-
-    return Dataset(
-        id="uk",
-        name="UK",
-        description="A representative dataset for the UK, based on the Family Resources Survey.",
-        year=year,
-        model=policyengine_uk_model,
-        data=tables,
-    )
-
-
-def create_us_dataset(
-    dataset: str = "enhanced_cps_2024.h5",
-    year: int = 2024,
-):
-    from policyengine_us import Microsimulation
-
-    from policyengine.models.policyengine_us import policyengine_us_model
-
-    sim = Microsimulation(
-        dataset="hf://policyengine/policyengine-us-data/" + dataset
-    )
-    sim.default_calculation_period = year
-
-    known_variables = sim.input_variables
-
-    tables = {
-        "person": pd.DataFrame(),
-        "marital_unit": pd.DataFrame(),
-        "tax_unit": pd.DataFrame(),
-        "spm_unit": pd.DataFrame(),
-        "family": pd.DataFrame(),
-        
"household": pd.DataFrame(), - } - - for variable in known_variables: - entity = sim.tax_benefit_system.variables[variable].entity.key - if variable in sim.tax_benefit_system.variables: - tables[entity][variable] = sim.calculate(variable) - - return Dataset( - id="us", - name="US", - description="A representative dataset for the US, based on the Current Population Survey.", - year=year, - model=policyengine_us_model, - data=tables, - ) diff --git a/tests/test_aggregate.py b/tests/test_aggregate.py deleted file mode 100644 index d956e1c8..00000000 --- a/tests/test_aggregate.py +++ /dev/null @@ -1,451 +0,0 @@ -""" -Tests for the clean aggregate implementation. - -Tests cover: -- Basic aggregations (sum, mean, median, count) -- Filtering (value, range, quantile) -- Cross-entity queries -- Batching efficiency -- Edge cases -""" - -import pytest -import pandas as pd - -from policyengine.models.aggregate import Aggregate, AggregateType - - -class MockSimulation: - """Mock simulation for testing.""" - - def __init__(self, result, year=2024): - self.result = result - self.dataset = MockDataset(year) - - -class MockDataset: - def __init__(self, year): - self.year = year - - -@pytest.fixture -def sample_tables(): - """Create sample person/household tables for testing.""" - person = pd.DataFrame({ - 'person_id': [0, 1, 2, 3], - 'person_household_id': [0, 0, 1, 1], - 'person_weight': [100.0, 100.0, 200.0, 200.0], - 'age': [30, 5, 45, 40], - 'employment_income': [50000, 0, 60000, 55000], - }) - - household = pd.DataFrame({ - 'household_id': [0, 1], - 'household_weight': [100.0, 200.0], - 'household_net_income': [50000, 115000], - 'is_in_poverty': [1, 0], - }) - - return {'person': person, 'household': household} - - -class TestBasicAggregations: - """Test basic aggregation functions.""" - - def test_sum(self, sample_tables): - """Test sum aggregation.""" - sim = MockSimulation(sample_tables) - agg = Aggregate( - simulation=sim, - variable_name='employment_income', - aggregate_function=AggregateType.SUM, - entity='person' - ) - results = Aggregate.run([agg]) - # Weighted sum: 50000*100 + 0*100 + 60000*200 + 55000*200 = 28,000,000 - assert results[0].value == 28_000_000.0 - - def test_mean(self, sample_tables): - """Test mean aggregation.""" - sim = MockSimulation(sample_tables) - agg = Aggregate( - simulation=sim, - variable_name='age', - aggregate_function=AggregateType.MEAN, - entity='person' - ) - results = Aggregate.run([agg]) - # Weighted mean: (30*100 + 5*100 + 45*200 + 40*200) / 600 = 34.17 - assert round(results[0].value, 2) == 34.17 - - def test_count(self, sample_tables): - """Test count aggregation.""" - sim = MockSimulation(sample_tables) - agg = Aggregate( - simulation=sim, - variable_name='person_id', - aggregate_function=AggregateType.COUNT, - entity='person' - ) - results = Aggregate.run([agg]) - # Weighted count: sum of person weights = 100 + 100 + 200 + 200 = 600 - assert results[0].value == 600.0 - - def test_median(self, sample_tables): - """Test median aggregation.""" - sim = MockSimulation(sample_tables) - agg = Aggregate( - simulation=sim, - variable_name='age', - aggregate_function=AggregateType.MEDIAN, - entity='person' - ) - results = Aggregate.run([agg]) - assert results[0].value > 0 - - def test_entity_inference(self, sample_tables): - """Test that entity is inferred correctly.""" - sim = MockSimulation(sample_tables) - agg = Aggregate( - simulation=sim, - variable_name='employment_income', - aggregate_function=AggregateType.SUM - # entity not specified - ) - results = 
Aggregate.run([agg]) - assert results[0].entity == 'person' - assert results[0].value == 28_000_000.0 - - -class TestFiltering: - """Test filtering functionality.""" - - def test_value_filter(self, sample_tables): - """Test filtering with exact value match.""" - sim = MockSimulation(sample_tables) - agg = Aggregate( - simulation=sim, - variable_name='person_id', - aggregate_function=AggregateType.COUNT, - entity='person', - filter_variable_name='age', - filter_variable_value=30 - ) - results = Aggregate.run([agg]) - # Weighted count: person 0 has age 30 and weight 100 - assert results[0].value == 100.0 - - def test_range_filter_leq(self, sample_tables): - """Test filtering with <= operator.""" - sim = MockSimulation(sample_tables) - agg = Aggregate( - simulation=sim, - variable_name='employment_income', - aggregate_function=AggregateType.SUM, - entity='person', - filter_variable_name='age', - filter_variable_leq=35 - ) - results = Aggregate.run([agg]) - # Persons with age <= 35: person 0 (age 30) and person 1 (age 5) - # Weighted sum: 50000*100 + 0*100 = 5,000,000 - assert results[0].value == 5_000_000.0 - - def test_range_filter_geq(self, sample_tables): - """Test filtering with >= operator.""" - sim = MockSimulation(sample_tables) - agg = Aggregate( - simulation=sim, - variable_name='employment_income', - aggregate_function=AggregateType.SUM, - entity='person', - filter_variable_name='age', - filter_variable_geq=40 - ) - results = Aggregate.run([agg]) - # Persons with age >= 40: person 2 (age 45) and person 3 (age 40) - # Weighted sum: 60000*200 + 55000*200 = 23,000,000 - assert results[0].value == 23_000_000.0 - - def test_combined_range_filters(self, sample_tables): - """Test combining leq and geq filters.""" - sim = MockSimulation(sample_tables) - agg = Aggregate( - simulation=sim, - variable_name='employment_income', - aggregate_function=AggregateType.SUM, - entity='person', - filter_variable_name='age', - filter_variable_geq=18, - filter_variable_leq=35 - ) - results = Aggregate.run([agg]) - # Person 0: age 30, income 50000, weight 100 - assert results[0].value == 5_000_000.0 - - def test_quantile_filter_leq(self, sample_tables): - """Test filtering with quantile_leq.""" - sim = MockSimulation(sample_tables) - agg = Aggregate( - simulation=sim, - variable_name='person_id', - aggregate_function=AggregateType.COUNT, - entity='person', - filter_variable_name='age', - filter_variable_quantile_leq=0.5 - ) - results = Aggregate.run([agg]) - # Weighted median age is 40, so includes ages <= 40: persons 0, 1, 3 - # Weighted count: 100 + 100 + 200 = 400 - assert results[0].value == 400.0 - - def test_quantile_filter_geq(self, sample_tables): - """Test filtering with quantile_geq.""" - sim = MockSimulation(sample_tables) - agg = Aggregate( - simulation=sim, - variable_name='person_id', - aggregate_function=AggregateType.COUNT, - entity='person', - filter_variable_name='age', - filter_variable_quantile_geq=0.5 - ) - results = Aggregate.run([agg]) - # Weighted median age is 40, so includes ages >= 40: persons 2, 3 - # Weighted count: 200 + 200 = 400 - assert results[0].value == 400.0 - - -class TestCrossEntity: - """Test cross-entity queries.""" - - def test_household_filter_on_person_aggregation(self, sample_tables): - """Test filtering persons by household variable.""" - sim = MockSimulation(sample_tables) - agg = Aggregate( - simulation=sim, - variable_name='person_id', - aggregate_function=AggregateType.COUNT, - entity='person', - filter_variable_name='is_in_poverty', - 
filter_variable_value=1 - ) - results = Aggregate.run([agg]) - # Persons in poor households (household 0): persons 0 and 1 - # Weighted count: 100 + 100 = 200 - assert results[0].value == 200.0 - - def test_person_to_household_aggregation(self, sample_tables): - """Test aggregating person variable at household level.""" - sim = MockSimulation(sample_tables) - agg = Aggregate( - simulation=sim, - variable_name='employment_income', - aggregate_function=AggregateType.SUM, - entity='household' - ) - results = Aggregate.run([agg]) - # Employment income summed to household level with household weights: - # Household 0: (50000 + 0) * 100 = 5,000,000 - # Household 1: (60000 + 55000) * 200 = 23,000,000 - # Total weighted sum: 28,000,000 - assert results[0].value == 28_000_000.0 - - def test_poverty_rate_calculation(self, sample_tables): - """Test calculating poverty rate.""" - sim = MockSimulation(sample_tables) - - # Count persons in poverty - poor = Aggregate( - simulation=sim, - variable_name='person_id', - aggregate_function=AggregateType.COUNT, - entity='person', - filter_variable_name='is_in_poverty', - filter_variable_value=1 - ) - - # Total persons - total = Aggregate( - simulation=sim, - variable_name='person_id', - aggregate_function=AggregateType.COUNT, - entity='person' - ) - - results = Aggregate.run([poor, total]) - poverty_rate = results[0].value / results[1].value - # Weighted: 200 poor / 600 total = 1/3 - assert round(poverty_rate, 3) == 0.333 - - def test_mean_income_for_poor(self, sample_tables): - """Test mean income for persons in poor households.""" - sim = MockSimulation(sample_tables) - agg = Aggregate( - simulation=sim, - variable_name='employment_income', - aggregate_function=AggregateType.MEAN, - entity='person', - filter_variable_name='is_in_poverty', - filter_variable_value=1 - ) - results = Aggregate.run([agg]) - # Persons in poverty: person 0 (income 50000, weight 100), person 1 (income 0, weight 100) - # Weighted mean: (50000*100 + 0*100) / 200 = 25000 - assert results[0].value == 25000.0 - - -class TestBatching: - """Test batch processing efficiency.""" - - def test_batch_same_simulation(self, sample_tables): - """Test that aggregates with same simulation are batched.""" - sim = MockSimulation(sample_tables) - - aggregates = [ - Aggregate( - simulation=sim, - variable_name='employment_income', - aggregate_function=AggregateType.SUM, - entity='person' - ), - Aggregate( - simulation=sim, - variable_name='age', - aggregate_function=AggregateType.MEAN, - entity='person' - ), - Aggregate( - simulation=sim, - variable_name='person_id', - aggregate_function=AggregateType.COUNT, - entity='person' - ), - ] - - results = Aggregate.run(aggregates) - assert len(results) == 3 - assert results[0].value == 28_000_000.0 - assert round(results[1].value, 2) == 34.17 - assert results[2].value == 600.0 # Weighted count - - def test_batch_different_filters(self, sample_tables): - """Test batching aggregates with different filters.""" - sim = MockSimulation(sample_tables) - - aggregates = [ - Aggregate( - simulation=sim, - variable_name='person_id', - aggregate_function=AggregateType.COUNT, - entity='person', - filter_variable_name='age', - filter_variable_leq=17 - ), - Aggregate( - simulation=sim, - variable_name='person_id', - aggregate_function=AggregateType.COUNT, - entity='person', - filter_variable_name='age', - filter_variable_geq=18 - ), - ] - - results = Aggregate.run(aggregates) - assert len(results) == 2 - assert results[0].value == 100.0 # Children: person 1 weight 100 - 
assert results[1].value == 500.0  # Adults: persons 0,2,3 weights 100+200+200
-
-
-class TestEdgeCases:
-    """Test edge cases."""
-
-    def test_empty_result(self, sample_tables):
-        """Test filtering that results in empty set."""
-        sim = MockSimulation(sample_tables)
-        agg = Aggregate(
-            simulation=sim,
-            variable_name='person_id',
-            aggregate_function=AggregateType.COUNT,
-            entity='person',
-            filter_variable_name='age',
-            filter_variable_value=999
-        )
-        results = Aggregate.run([agg])
-        assert results[0].value == 0.0
-
-    def test_weight_column_sum(self, sample_tables):
-        """Test that weight columns avoid double-weighting."""
-        sim = MockSimulation(sample_tables)
-        agg = Aggregate(
-            simulation=sim,
-            variable_name='person_weight',
-            aggregate_function=AggregateType.SUM,
-            entity='person'
-        )
-        results = Aggregate.run([agg])
-        # Simple sum (not weighted): 100 + 100 + 200 + 200 = 600
-        assert results[0].value == 600.0
-
-    def test_missing_variable(self, sample_tables):
-        """Test error when variable doesn't exist."""
-        sim = MockSimulation(sample_tables)
-        agg = Aggregate(
-            simulation=sim,
-            variable_name='nonexistent',
-            aggregate_function=AggregateType.SUM
-        )
-        with pytest.raises(ValueError, match='not found'):
-            Aggregate.run([agg])
-
-
-class TestComplexScenarios:
-    """Test complex real-world scenarios."""
-
-    def test_poverty_by_age_group(self, sample_tables):
-        """Test the age-group breakdown used in poverty-by-age analysis."""
-        sim = MockSimulation(sample_tables)
-
-        # Children (age <= 17); no poverty filter is applied in this scenario
-        children = Aggregate(
-            simulation=sim,
-            variable_name='person_id',
-            aggregate_function=AggregateType.COUNT,
-            entity='person',
-            filter_variable_name='age',
-            filter_variable_leq=17
-        )
-
-        results = Aggregate.run([children])
-        assert results[0].value == 100.0  # Person 1 (age 5, weight 100)
-
-    def test_multiple_aggregations(self, sample_tables):
-        """Test running multiple different aggregations together."""
-        sim = MockSimulation(sample_tables)
-
-        aggs = [
-            Aggregate(
-                simulation=sim,
-                variable_name='employment_income',
-                aggregate_function=AggregateType.SUM,
-                entity='person'
-            ),
-            Aggregate(
-                simulation=sim,
-                variable_name='employment_income',
-                aggregate_function=AggregateType.MEAN,
-                entity='person'
-            ),
-            Aggregate(
-                simulation=sim,
-                variable_name='employment_income',
-                aggregate_function=AggregateType.MEDIAN,
-                entity='person'
-            ),
-        ]
-
-        results = Aggregate.run(aggs)
-        assert len(results) == 3
-        assert results[0].value > results[1].value > 0
-        assert results[2].value > 0
diff --git a/tests/test_aggregate_change.py b/tests/test_aggregate_change.py
deleted file mode 100644
index 11938a5d..00000000
--- a/tests/test_aggregate_change.py
+++ /dev/null
@@ -1,479 +0,0 @@
-"""
-Tests for the clean AggregateChange implementation.
- -Tests cover: -- Basic change calculations -- Relative change calculations -- Cross-entity filters -- Batching multiple changes -- Edge cases -""" - -import pytest -import pandas as pd - -from policyengine.models.aggregate_change import AggregateChange -from policyengine.models.aggregate import AggregateType - - -class MockSimulation: - """Mock simulation for testing.""" - - def __init__(self, result, year=2024, sim_id=None): - self.result = result - self.dataset = MockDataset(year) - self.id = sim_id or "sim_123" - - -class MockDataset: - def __init__(self, year): - self.year = year - - -@pytest.fixture -def baseline_tables(): - """Baseline simulation tables.""" - person = pd.DataFrame({ - 'person_id': [0, 1, 2, 3], - 'person_household_id': [0, 0, 1, 1], - 'person_weight': [100.0, 100.0, 200.0, 200.0], - 'age': [30, 5, 45, 40], - 'employment_income': [50000, 0, 60000, 55000], - 'benefits': [5000, 2000, 0, 0], - }) - - household = pd.DataFrame({ - 'household_id': [0, 1], - 'household_weight': [100.0, 200.0], - 'household_net_income': [57000, 115000], - 'is_in_poverty': [1, 0], - }) - - return {'person': person, 'household': household} - - -@pytest.fixture -def comparison_tables(): - """Comparison simulation tables (with policy change).""" - person = pd.DataFrame({ - 'person_id': [0, 1, 2, 3], - 'person_household_id': [0, 0, 1, 1], - 'person_weight': [100.0, 100.0, 200.0, 200.0], - 'age': [30, 5, 45, 40], - 'employment_income': [50000, 0, 60000, 55000], - 'benefits': [8000, 3000, 1000, 1000], # Benefits increased - }) - - household = pd.DataFrame({ - 'household_id': [0, 1], - 'household_weight': [100.0, 200.0], - 'household_net_income': [61000, 117000], # Incomes increased - 'is_in_poverty': [0, 0], # Household 0 lifted out of poverty - }) - - return {'person': person, 'household': household} - - -class TestBasicChanges: - """Test basic change calculations.""" - - def test_simple_change(self, baseline_tables, comparison_tables): - """Test calculating a simple change in totals.""" - baseline_sim = MockSimulation(baseline_tables) - comparison_sim = MockSimulation(comparison_tables) - - agg_change = AggregateChange( - baseline_simulation=baseline_sim, - comparison_simulation=comparison_sim, - variable_name='benefits', - aggregate_function=AggregateType.SUM, - entity='person' - ) - - results = AggregateChange.run([agg_change]) - result = results[0] - - # Baseline weighted sum: 5000*100 + 2000*100 + 0*200 + 0*200 = 700,000 - assert result.baseline_value == 700_000.0 - - # Comparison weighted sum: 8000*100 + 3000*100 + 1000*200 + 1000*200 = 1,500,000 - assert result.comparison_value == 1_500_000.0 - - # Change: 1,500,000 - 700,000 = 800,000 - assert result.change == 800_000.0 - - # Relative change: 800,000 / 700,000 ≈ 1.14 - assert round(result.relative_change, 2) == 1.14 - - def test_mean_change(self, baseline_tables, comparison_tables): - """Test calculating change in mean values.""" - baseline_sim = MockSimulation(baseline_tables) - comparison_sim = MockSimulation(comparison_tables) - - agg_change = AggregateChange( - baseline_simulation=baseline_sim, - comparison_simulation=comparison_sim, - variable_name='benefits', - aggregate_function=AggregateType.MEAN, - entity='person' - ) - - results = AggregateChange.run([agg_change]) - result = results[0] - - # Baseline weighted mean: 700,000 / 600 = 1,166.67 - assert round(result.baseline_value, 2) == 1166.67 - - # Comparison weighted mean: 1,500,000 / 600 = 2,500 - assert result.comparison_value == 2500.0 - - # Change: 2,500 - 1,166.67 = 1,333.33 
- assert round(result.change, 2) == 1333.33 - - def test_count_change(self, baseline_tables, comparison_tables): - """Test change in counts (e.g., poverty count).""" - baseline_sim = MockSimulation(baseline_tables) - comparison_sim = MockSimulation(comparison_tables) - - # Count households in poverty - agg_change = AggregateChange( - baseline_simulation=baseline_sim, - comparison_simulation=comparison_sim, - variable_name='household_id', - aggregate_function=AggregateType.COUNT, - entity='household', - filter_variable_name='is_in_poverty', - filter_variable_value=1 - ) - - results = AggregateChange.run([agg_change]) - result = results[0] - - # Baseline: household 0 in poverty with weight 100 - assert result.baseline_value == 100.0 - - # Comparison: 0 households in poverty - assert result.comparison_value == 0.0 - - # Change: -100 (weighted household count) - assert result.change == -100.0 - - -class TestCrossEntityChanges: - """Test changes with cross-entity filters.""" - - def test_persons_in_poverty_change(self, baseline_tables, comparison_tables): - """Test change in count of persons in poor households.""" - baseline_sim = MockSimulation(baseline_tables) - comparison_sim = MockSimulation(comparison_tables) - - # Count persons in poor households - agg_change = AggregateChange( - baseline_simulation=baseline_sim, - comparison_simulation=comparison_sim, - variable_name='person_id', - aggregate_function=AggregateType.COUNT, - entity='person', - filter_variable_name='is_in_poverty', - filter_variable_value=1 - ) - - results = AggregateChange.run([agg_change]) - result = results[0] - - # Baseline: persons 0, 1 in poor households with weights 100 + 100 = 200 - assert result.baseline_value == 200.0 - - # Comparison: 0 persons in poor households - assert result.comparison_value == 0.0 - - # Change: -200 (weighted person count) - assert result.change == -200.0 - - def test_mean_benefits_for_poor(self, baseline_tables, comparison_tables): - """Test change in mean benefits for persons in poor households.""" - baseline_sim = MockSimulation(baseline_tables) - comparison_sim = MockSimulation(comparison_tables) - - agg_change = AggregateChange( - baseline_simulation=baseline_sim, - comparison_simulation=comparison_sim, - variable_name='benefits', - aggregate_function=AggregateType.MEAN, - entity='person', - filter_variable_name='is_in_poverty', - filter_variable_value=1 - ) - - results = AggregateChange.run([agg_change]) - result = results[0] - - # Baseline: persons 0 and 1 in poverty - # Weighted mean: (5000*100 + 2000*100) / 200 = 3,500 - assert result.baseline_value == 3500.0 - - # Comparison: 0 persons in poverty (empty filter) - assert result.comparison_value == 0.0 - - -class TestBatching: - """Test efficient batching of multiple changes.""" - - def test_batch_multiple_changes(self, baseline_tables, comparison_tables): - """Test processing multiple aggregate changes efficiently.""" - baseline_sim = MockSimulation(baseline_tables) - comparison_sim = MockSimulation(comparison_tables) - - changes = [ - AggregateChange( - baseline_simulation=baseline_sim, - comparison_simulation=comparison_sim, - variable_name='benefits', - aggregate_function=AggregateType.SUM, - entity='person' - ), - AggregateChange( - baseline_simulation=baseline_sim, - comparison_simulation=comparison_sim, - variable_name='employment_income', - aggregate_function=AggregateType.MEAN, - entity='person' - ), - AggregateChange( - baseline_simulation=baseline_sim, - comparison_simulation=comparison_sim, - variable_name='person_id', - 
aggregate_function=AggregateType.COUNT, - entity='person', - filter_variable_name='is_in_poverty', - filter_variable_value=1 - ), - ] - - results = AggregateChange.run(changes) - - assert len(results) == 3 - assert results[0].change == 800_000.0 # Benefits increased - assert results[1].change == 0.0 # Employment income unchanged - assert results[2].change == -200.0 # Poverty count decreased (weighted) - - -class TestRangeFilters: - """Test aggregate changes with range filters.""" - - def test_change_with_age_filter(self, baseline_tables, comparison_tables): - """Test change in benefits for specific age group.""" - baseline_sim = MockSimulation(baseline_tables) - comparison_sim = MockSimulation(comparison_tables) - - # Benefits for children (age < 18) - agg_change = AggregateChange( - baseline_simulation=baseline_sim, - comparison_simulation=comparison_sim, - variable_name='benefits', - aggregate_function=AggregateType.SUM, - entity='person', - filter_variable_name='age', - filter_variable_leq=17 - ) - - results = AggregateChange.run([agg_change]) - result = results[0] - - # Person 1 (age 5): baseline 2000*100, comparison 3000*100 - assert result.baseline_value == 200_000.0 - assert result.comparison_value == 300_000.0 - assert result.change == 100_000.0 - - def test_change_with_quantile_filter(self, baseline_tables, comparison_tables): - """Test change for income quantiles.""" - baseline_sim = MockSimulation(baseline_tables) - comparison_sim = MockSimulation(comparison_tables) - - # Benefits for bottom 50% by income - agg_change = AggregateChange( - baseline_simulation=baseline_sim, - comparison_simulation=comparison_sim, - variable_name='benefits', - aggregate_function=AggregateType.MEAN, - entity='person', - filter_variable_name='employment_income', - filter_variable_quantile_leq=0.5 - ) - - results = AggregateChange.run([agg_change]) - result = results[0] - - # Should get results for lower-income persons - assert result.baseline_value >= 0 - assert result.comparison_value >= 0 - - -class TestEdgeCases: - """Test edge cases.""" - - def test_zero_baseline_value(self, baseline_tables, comparison_tables): - """Test relative change when baseline is zero.""" - # Create tables where baseline has zero value - baseline_zero = { - 'person': pd.DataFrame({ - 'person_id': [0, 1], - 'person_weight': [1.0, 1.0], - 'new_benefit': [0, 0] - }) - } - - comparison_nonzero = { - 'person': pd.DataFrame({ - 'person_id': [0, 1], - 'person_weight': [1.0, 1.0], - 'new_benefit': [1000, 1000] - }) - } - - baseline_sim = MockSimulation(baseline_zero) - comparison_sim = MockSimulation(comparison_nonzero) - - agg_change = AggregateChange( - baseline_simulation=baseline_sim, - comparison_simulation=comparison_sim, - variable_name='new_benefit', - aggregate_function=AggregateType.SUM, - entity='person' - ) - - results = AggregateChange.run([agg_change]) - result = results[0] - - assert result.baseline_value == 0.0 - assert result.comparison_value == 2000.0 - assert result.change == 2000.0 - assert result.relative_change == float('inf') - - def test_both_zero(self): - """Test when both baseline and comparison are zero.""" - baseline = { - 'person': pd.DataFrame({ - 'person_id': [0, 1], - 'person_weight': [1.0, 1.0], - 'value': [0, 0] - }) - } - - baseline_sim = MockSimulation(baseline) - comparison_sim = MockSimulation(baseline) - - agg_change = AggregateChange( - baseline_simulation=baseline_sim, - comparison_simulation=comparison_sim, - variable_name='value', - aggregate_function=AggregateType.SUM, - entity='person' 
- ) - - results = AggregateChange.run([agg_change]) - result = results[0] - - assert result.baseline_value == 0.0 - assert result.comparison_value == 0.0 - assert result.change == 0.0 - assert result.relative_change is None - - def test_missing_simulation(self): - """Test error when simulation is missing.""" - agg_change = AggregateChange( - variable_name='value', - aggregate_function=AggregateType.SUM - ) - - with pytest.raises(ValueError, match='missing baseline_simulation'): - AggregateChange.run([agg_change]) - - def test_negative_change(self, baseline_tables, comparison_tables): - """Test calculating negative changes correctly.""" - # Create scenario where value decreases - baseline_high = { - 'person': pd.DataFrame({ - 'person_id': [0, 1], - 'person_weight': [1.0, 1.0], - 'value': [1000, 1000] - }) - } - - comparison_low = { - 'person': pd.DataFrame({ - 'person_id': [0, 1], - 'person_weight': [1.0, 1.0], - 'value': [500, 500] - }) - } - - baseline_sim = MockSimulation(baseline_high) - comparison_sim = MockSimulation(comparison_low) - - agg_change = AggregateChange( - baseline_simulation=baseline_sim, - comparison_simulation=comparison_sim, - variable_name='value', - aggregate_function=AggregateType.SUM, - entity='person' - ) - - results = AggregateChange.run([agg_change]) - result = results[0] - - assert result.baseline_value == 2000.0 - assert result.comparison_value == 1000.0 - assert result.change == -1000.0 - assert result.relative_change == -0.5 - - -class TestRealWorldScenarios: - """Test realistic policy analysis scenarios.""" - - def test_poverty_impact_analysis(self, baseline_tables, comparison_tables): - """Test complete poverty impact analysis.""" - baseline_sim = MockSimulation(baseline_tables) - comparison_sim = MockSimulation(comparison_tables) - - analysis = [ - # Total poverty count - AggregateChange( - baseline_simulation=baseline_sim, - comparison_simulation=comparison_sim, - variable_name='household_id', - aggregate_function=AggregateType.COUNT, - entity='household', - filter_variable_name='is_in_poverty', - filter_variable_value=1 - ), - # Persons in poverty - AggregateChange( - baseline_simulation=baseline_sim, - comparison_simulation=comparison_sim, - variable_name='person_id', - aggregate_function=AggregateType.COUNT, - entity='person', - filter_variable_name='is_in_poverty', - filter_variable_value=1 - ), - # Mean benefits for poor households - AggregateChange( - baseline_simulation=baseline_sim, - comparison_simulation=comparison_sim, - variable_name='benefits', - aggregate_function=AggregateType.MEAN, - entity='person', - filter_variable_name='is_in_poverty', - filter_variable_value=1 - ), - ] - - results = AggregateChange.run(analysis) - - # Poverty decreased - assert results[0].change < 0 # Fewer poor households - assert results[1].change < 0 # Fewer poor persons - - # Benefits increased for those who were poor - assert results[2].baseline_value > 0 diff --git a/tests/test_database_init.py b/tests/test_database_init.py deleted file mode 100644 index b343e91e..00000000 --- a/tests/test_database_init.py +++ /dev/null @@ -1,259 +0,0 @@ -"""Test database initialization and table creation.""" - -import sys -from pathlib import Path - -import pytest -from sqlalchemy import inspect - -# Add src to path to allow imports -sys.path.insert(0, str(Path(__file__).parent.parent / "src")) - - -# Module-level functions for testing (can be pickled) -# Note: Don't use test_ prefix or pytest will try to run them -def sim_func(x): - """Simulation function that can be 
pickled.""" - return x - - -def sim_func_double(x): - """Simulation function that doubles input.""" - return x * 2 - - -@pytest.fixture -def fresh_database(): - """Create a fresh database instance for each test.""" - from policyengine.database import Database - - # Use in-memory SQLite for testing - db = Database(url="sqlite:///:memory:") - return db - - -def test_database_creates_engine(fresh_database): - """Test that database initialization creates an engine.""" - assert fresh_database.engine is not None - assert fresh_database.url == "sqlite:///:memory:" - - -def test_database_creates_session(fresh_database): - """Test that database initialization creates a session.""" - assert fresh_database.session is not None - - -def test_create_tables_creates_all_registered_tables(fresh_database): - """Test that create_tables() creates all registered tables.""" - fresh_database.create_tables() - - # Get inspector to check actual database tables - inspector = inspect(fresh_database.engine) - actual_tables = inspector.get_table_names() - - # Expected tables based on registered table links - expected_tables = { - "models", - "model_versions", - "datasets", - "versioned_datasets", - "policies", - "dynamics", - "parameters", - "parameter_values", - "baseline_parameter_values", - "baseline_variables", - "simulations", - "aggregates", - } - - # Check that all expected tables exist - for table in expected_tables: - assert table in actual_tables, f"Table {table} was not created" - - -def test_register_table_creates_table(fresh_database): - """Test that register_table registers the table link.""" - # Tables should be registered but NOT created until create_tables() is called - inspector = inspect(fresh_database.engine) - initial_tables = set(inspector.get_table_names()) - - # No tables should exist yet - assert len(initial_tables) == 0 - - # But table links should be registered - assert len(fresh_database._model_table_links) == 12 - - # After calling create_tables(), tables should exist - fresh_database.create_tables() - inspector = inspect(fresh_database.engine) - tables_after = set(inspector.get_table_names()) - assert "models" in tables_after - - -def test_reset_drops_and_recreates_tables(fresh_database): - """Test that reset() drops and recreates all tables.""" - from policyengine.models import Model - - # Create tables first - fresh_database.create_tables() - - # Add some data - model = Model( - id="test_model", - name="Test", - description="Test model", - simulation_function=sim_func, - ) - fresh_database.set(model) - - # Verify data exists - retrieved = fresh_database.get(Model, id="test_model") - assert retrieved is not None - - # Reset the database - fresh_database.reset() - - # Tables should exist but be empty - inspector = inspect(fresh_database.engine) - tables = inspector.get_table_names() - assert "models" in tables - - # Data should be gone - retrieved_after_reset = fresh_database.get(Model, id="test_model") - assert retrieved_after_reset is None - - -def test_drop_tables_removes_all_tables(fresh_database): - """Test that drop_tables() removes all tables.""" - # Create tables first - fresh_database.create_tables() - - # Verify tables exist - inspector = inspect(fresh_database.engine) - tables_before = inspector.get_table_names() - assert len(tables_before) > 0 - - # Drop tables - fresh_database.drop_tables() - - # Verify tables are gone - inspector = inspect(fresh_database.engine) - tables_after = inspector.get_table_names() - assert len(tables_after) == 0 - - -def 
test_context_manager_commits_on_success(): - """Test that context manager commits on successful operations.""" - from policyengine.database import Database - from policyengine.models import Model - - db = Database(url="sqlite:///:memory:") - db.create_tables() - - # Use context manager to add data - with db as session: - model = Model( - id="test_context_model", - name="Context Test", - description="Testing context manager", - simulation_function=sim_func, - ) - db.set(model, commit=False) # Don't commit inside set - - # Data should be committed after context exit - retrieved = db.get(Model, id="test_context_model") - assert retrieved is not None - assert retrieved.name == "Context Test" - - -def test_context_manager_rolls_back_on_error(): - """Test that context manager rolls back on errors.""" - from policyengine.database import Database - from policyengine.models import Model - - db = Database(url="sqlite:///:memory:") - db.create_tables() - - # Try to use context manager with an error - try: - with db as session: - model = Model( - id="test_rollback_model", - name="Rollback Test", - description="Testing rollback", - simulation_function=sim_func, - ) - db.set(model, commit=False) - # Raise an error to trigger rollback - raise ValueError("Test error") - except ValueError: - pass - - # Data should NOT be in database due to rollback - retrieved = db.get(Model, id="test_rollback_model") - assert retrieved is None - - -def test_database_url_variations(): - """Test that database works with different URL formats.""" - from policyengine.database import Database - - # Test in-memory SQLite - db1 = Database(url="sqlite:///:memory:") - assert db1.engine is not None - - # Test file-based SQLite - db2 = Database(url="sqlite:///test.db") - assert db2.engine is not None - - -def test_all_table_links_registered(fresh_database): - """Test that all expected table links are registered.""" - expected_count = 12 # Based on the number of table links in __init__ - assert len(fresh_database._model_table_links) == expected_count - - # Verify specific table links exist - from policyengine.models import ( - Aggregate, - Dataset, - Dynamic, - Model, - ModelVersion, - Parameter, - Policy, - Simulation, - VersionedDataset, - ) - - model_classes = [link.model_cls for link in fresh_database._model_table_links] - - assert Model in model_classes - assert ModelVersion in model_classes - assert Dataset in model_classes - assert VersionedDataset in model_classes - assert Policy in model_classes - assert Dynamic in model_classes - assert Parameter in model_classes - assert Simulation in model_classes - assert Aggregate in model_classes - - -def test_verify_tables_exist(fresh_database): - """Test the verify_tables_exist method.""" - # Before creating tables - results_before = fresh_database.verify_tables_exist() - # Some tables may exist from register_table calls during __init__ - # So we just check the method runs - - # After creating tables - fresh_database.create_tables() - results_after = fresh_database.verify_tables_exist() - - # All tables should exist now - assert all(results_after.values()), f"Some tables don't exist: {results_after}" - - # Check specific tables - assert results_after.get("models") is True - assert results_after.get("simulations") is True - assert results_after.get("parameters") is True diff --git a/tests/test_database_models.py b/tests/test_database_models.py deleted file mode 100644 index ed6473c0..00000000 --- a/tests/test_database_models.py +++ /dev/null @@ -1,384 +0,0 @@ -"""Test database model 
tables for set and get operations.""" - -import sys -from datetime import datetime -from pathlib import Path - -import pytest - -# Add src to path to allow imports -sys.path.insert(0, str(Path(__file__).parent.parent / "src")) - - -# Define functions at module level to make them pickleable (not test_ prefix to avoid pytest) -def simulation_function(x): - return x * 2 - - -def policy_modifier_function(sim): - sim.set_parameter("tax_rate", 0.25) - return sim - - -def dynamic_modifier_function(sim): - sim.set_parameter("benefit_amount", 1000) - return sim - - -@pytest.fixture -def fresh_database(): - """Create a fresh database instance for each test.""" - # Import here to avoid circular imports - from policyengine.database import Database - - # Use in-memory SQLite for testing - db = Database(url="sqlite:///:memory:") - db.create_tables() - return db - - -def test_model_table_set_and_get(fresh_database): - """Test ModelTable set and get operations.""" - from policyengine.models import Model - - model = Model( - id="test_model_1", - name="Test model", - description="A test model", - simulation_function=simulation_function, - ) - - # Set the model - fresh_database.set(model) - - # Get the model - retrieved_model = fresh_database.get(Model, id="test_model_1") - - assert retrieved_model is not None - assert retrieved_model.id == "test_model_1" - assert retrieved_model.name == "Test model" - assert retrieved_model.description == "A test model" - assert retrieved_model.simulation_function(5) == 10 - - -def test_dataset_table_set_and_get(fresh_database): - """Test DatasetTable set and get operations.""" - from policyengine.models import Dataset - - test_data = {"households": [{"id": 1, "income": 50000}]} - - dataset = Dataset( - id="test_dataset_1", - name="Test dataset", - data=test_data, - ) - - fresh_database.set(dataset) - retrieved = fresh_database.get(Dataset, id="test_dataset_1") - - assert retrieved is not None - assert retrieved.id == "test_dataset_1" - assert retrieved.name == "Test dataset" - assert retrieved.data == test_data - - -def test_versioned_dataset_table_set_and_get(fresh_database): - """Test VersionedDatasetTable set and get operations.""" - from policyengine.models import VersionedDataset - - versioned_dataset = VersionedDataset( - id="test_versioned_1", - name="Test versioned dataset", - description="A test versioned dataset", - ) - - fresh_database.set(versioned_dataset) - retrieved = fresh_database.get(VersionedDataset, id="test_versioned_1") - - assert retrieved is not None - assert retrieved.id == "test_versioned_1" - assert retrieved.name == "Test versioned dataset" - assert retrieved.description == "A test versioned dataset" - - -def test_policy_table_set_and_get(fresh_database): - """Test PolicyTable set and get operations.""" - from policyengine.models import Policy - - policy = Policy( - id="test_policy_1", - name="Test policy", - description="A test policy", - simulation_modifier=policy_modifier_function, - created_at=datetime.now(), - ) - - fresh_database.set(policy) - retrieved = fresh_database.get(Policy, id="test_policy_1") - - assert retrieved is not None - assert retrieved.id == "test_policy_1" - assert retrieved.name == "Test policy" - assert retrieved.description == "A test policy" - assert callable(retrieved.simulation_modifier) - - -def test_dynamic_table_set_and_get(fresh_database): - """Test DynamicTable set and get operations.""" - from policyengine.models import Dynamic - - dynamic = Dynamic( - id="test_dynamic_1", - name="Test dynamic", - 
description="A test dynamic policy", - simulation_modifier=dynamic_modifier_function, - created_at=datetime.now(), - ) - - fresh_database.set(dynamic) - retrieved = fresh_database.get(Dynamic, id="test_dynamic_1") - - assert retrieved is not None - assert retrieved.id == "test_dynamic_1" - assert retrieved.name == "Test dynamic" - assert retrieved.description == "A test dynamic policy" - assert callable(retrieved.simulation_modifier) - - -def test_baseline_parameter_value_table_set_and_get(fresh_database): - """Test BaselineParameterValueTable set and get operations.""" - from policyengine.models import ( - BaselineParameterValue, - Model, - ModelVersion, - Parameter, - ) - - # Create model, parameter and model version first - model = Model( - id="bpv_model", - name="BPV model", - description="Model for baseline parameter values", - simulation_function=simulation_function, - ) - fresh_database.set(model) - - parameter = Parameter( - id="test_param_1", - description="Test parameter", - data_type=float, - model=model, - ) - fresh_database.set(parameter) - - model_version = ModelVersion( - id="test_version_1", - model=model, - version="1.0.0", - ) - fresh_database.set(model_version) - - baseline_param = BaselineParameterValue( - parameter=parameter, - model_version=model_version, - value=0.2, - start_date=datetime(2024, 1, 1), - ) - - fresh_database.set(baseline_param) - # Note: BaselineParameterValue doesn't have an id field, so we need to query differently - # For now, we'll skip retrieval test for this model - # TODO: Add proper retrieval test for composite key models - - -def test_multiple_operations_on_same_table(fresh_database): - """Test multiple set and get operations on the same table.""" - from policyengine.models import Model - - # Create multiple model instances - models = [] - for i in range(3): - model = Model( - id=f"model_{i}", - name=f"Model {i}", - description=f"Model number {i}", - simulation_function=simulation_function, - ) - models.append(model) - fresh_database.set(model) - - # Retrieve all models - for i, original in enumerate(models): - retrieved = fresh_database.get(Model, id=f"model_{i}") - assert retrieved is not None - assert retrieved.id == original.id - assert retrieved.name == original.name - assert retrieved.description == original.description - - -def test_get_nonexistent_record(fresh_database): - """Test getting a record that doesn't exist.""" - from policyengine.models import Model - - result = fresh_database.get(Model, id="nonexistent_id") - assert result is None - - -def test_complex_data_compression(fresh_database): - """Test that complex data types are properly compressed and decompressed.""" - from policyengine.models import Dataset - - # Create a dataset with complex nested structure - complex_data = { - "households": [ - { - "id": i, - "income": 30000 + i * 5000, - "members": list(range(i + 1)), - } - for i in range(100) - ], - "metadata": { - "source": "test", - "year": 2024, - "nested": {"deep": {"structure": True}}, - }, - } - - dataset = Dataset( - id="complex_dataset", - name="Complex dataset", - data=complex_data, - ) - - fresh_database.set(dataset) - retrieved = fresh_database.get(Dataset, id="complex_dataset") - - assert retrieved is not None - assert retrieved.data == complex_data - assert retrieved.data["households"][50]["income"] == 280000 - assert retrieved.data["metadata"]["nested"]["deep"]["structure"] is True - - -def test_user_table_set_and_get(fresh_database): - """Test UserTable set and get operations.""" - from policyengine.models 
import User - - user = User( - id="test_user_1", - username="testuser", - email="test@example.com", - first_name="Test", - last_name="User", - created_at=datetime.now(), - ) - - fresh_database.set(user) - retrieved = fresh_database.get(User, id="test_user_1") - - assert retrieved is not None - assert retrieved.id == "test_user_1" - assert retrieved.username == "testuser" - assert retrieved.email == "test@example.com" - assert retrieved.first_name == "Test" - assert retrieved.last_name == "User" - - -def test_report_table_set_and_get(fresh_database): - """Test ReportTable set and get operations.""" - from policyengine.models import Report - - report = Report( - id="test_report_1", - label="Test Report", - created_at=datetime.now(), - ) - - fresh_database.set(report) - retrieved = fresh_database.get(Report, id="test_report_1") - - assert retrieved is not None - assert retrieved.id == "test_report_1" - assert retrieved.label == "Test Report" - - -def test_report_element_table_set_and_get(fresh_database): - """Test ReportElementTable set and get operations.""" - from policyengine.models import ReportElement - - element = ReportElement( - id="test_element_1", - label="Test Element", - type="chart", - chart_type="bar", - ) - - fresh_database.set(element) - retrieved = fresh_database.get(ReportElement, id="test_element_1") - - assert retrieved is not None - assert retrieved.id == "test_element_1" - assert retrieved.label == "Test Element" - assert retrieved.type == "chart" - assert retrieved.chart_type == "bar" - - -def test_aggregate_table_set_and_get(fresh_database): - """Test AggregateTable set and get operations.""" - from policyengine.models import ( - Aggregate, - Dataset, - Model, - ModelVersion, - Simulation, - ) - - # Create model first - model = Model( - id="agg_model", - name="Agg model", - description="Model for aggregates", - simulation_function=simulation_function, - ) - fresh_database.set(model) - - # Create model version - model_version = ModelVersion( - id="agg_version_1", - model=model, - version="1.0.0", - ) - fresh_database.set(model_version) - - # Create dataset - test_data = {"households": [{"id": 1, "income": 50000}]} - dataset = Dataset( - id="agg_dataset_1", - name="Agg dataset", - data=test_data, - ) - fresh_database.set(dataset) - - # Create simulation - simulation = Simulation( - id="agg_sim_1", - dataset=dataset, - model=model, - model_version=model_version, - ) - fresh_database.set(simulation) - - aggregate = Aggregate( - simulation=simulation, - entity="household", - variable_name="household_income", - aggregate_function="sum", - year=2024, - filter_variable_name="income_positive", - ) - - fresh_database.set(aggregate) - # Note: Aggregate doesn't have an id field or direct retrieval - # We'll skip retrieval test for now - # TODO: Add proper retrieval test for Aggregate model - assert True # Placeholder assertion diff --git a/tests/test_database_postgres.py b/tests/test_database_postgres.py deleted file mode 100644 index 1af9796d..00000000 --- a/tests/test_database_postgres.py +++ /dev/null @@ -1,169 +0,0 @@ -"""Test database with PostgreSQL/Supabase connection. - -These tests verify that table creation and commits work properly with PostgreSQL, -which is what Supabase uses. 
-""" - -import sys -from pathlib import Path - -import pytest -from sqlalchemy import inspect, text - -# Add src to path to allow imports -sys.path.insert(0, str(Path(__file__).parent.parent / "src")) - - -# Module-level functions for testing (can be pickled) -def sim_func(x): - """Simulation function that can be pickled.""" - return x - - -def sim_func_double(x): - """Simulation function that doubles input.""" - return x * 2 - - -@pytest.fixture -def postgres_database(): - """Create a database instance with a local PostgreSQL connection. - - This requires a local PostgreSQL server running on port 54322. - Skip the test if the connection fails. - """ - from policyengine.database import Database - - try: - db = Database(url="postgresql://postgres:postgres@127.0.0.1:54322/postgres") - # Test connection - with db.engine.connect() as conn: - conn.execute(text("SELECT 1")) - return db - except Exception as e: - pytest.skip(f"PostgreSQL not available: {e}") - - -def test_postgres_create_tables(postgres_database): - """Test that create_tables() works with PostgreSQL.""" - # Drop tables first to ensure clean state - postgres_database.drop_tables() - - # Create tables - postgres_database.create_tables() - - # Verify tables exist - inspector = inspect(postgres_database.engine) - actual_tables = inspector.get_table_names() - - expected_tables = { - "models", - "model_versions", - "datasets", - "versioned_datasets", - "policies", - "dynamics", - "parameters", - "parameter_values", - "baseline_parameter_values", - "baseline_variables", - "simulations", - "aggregates", - } - - for table in expected_tables: - assert table in actual_tables, f"Table {table} was not created in PostgreSQL" - - -def test_postgres_insert_and_retrieve(postgres_database): - """Test that data can be inserted and retrieved from PostgreSQL.""" - from policyengine.models import Model - - # Reset database - postgres_database.reset() - - # Create a model - model = Model( - id="postgres_test_model", - name="PostgreSQL Test", - description="Testing PostgreSQL", - simulation_function=sim_func_double, - ) - - # Insert - postgres_database.set(model) - - # Retrieve - retrieved = postgres_database.get(Model, id="postgres_test_model") - - assert retrieved is not None - assert retrieved.id == "postgres_test_model" - assert retrieved.name == "PostgreSQL Test" - assert retrieved.simulation_function(5) == 10 - - -def test_postgres_session_commit(postgres_database): - """Test that session commits work properly with PostgreSQL.""" - from policyengine.models import Model - - # Reset database - postgres_database.reset() - - # Add data without committing - model = Model( - id="commit_test_model", - name="Commit Test", - description="Testing commit", - simulation_function=sim_func, - ) - - # Set with explicit commit - postgres_database.set(model, commit=True) - - # Create a NEW database connection to verify commit - from policyengine.database import Database - - new_db = Database(url="postgresql://postgres:postgres@127.0.0.1:54322/postgres") - retrieved = new_db.get(Model, id="commit_test_model") - - assert retrieved is not None - assert retrieved.name == "Commit Test" - - -def test_postgres_table_persistence(postgres_database): - """Test that tables persist across database reconnections.""" - # Create tables - postgres_database.reset() - - # Close connection - postgres_database.session.close() - - # Create new database instance with same URL - from policyengine.database import Database - - new_db = 
Database(url="postgresql://postgres:postgres@127.0.0.1:54322/postgres") - - # Tables should still exist - inspector = inspect(new_db.engine) - tables = inspector.get_table_names() - - assert "models" in tables - assert "simulations" in tables - - -def test_postgres_register_model_version(postgres_database): - """Test that register_model_version works with PostgreSQL.""" - # This test verifies that the bulk registration of model versions - # properly commits to PostgreSQL - postgres_database.reset() - - # This would typically use policyengine_uk_latest_version - # For now, we'll just verify the reset worked - inspector = inspect(postgres_database.engine) - tables = inspector.get_table_names() - - assert "models" in tables - assert "model_versions" in tables - assert "parameters" in tables - assert "baseline_parameter_values" in tables - assert "baseline_variables" in tables diff --git a/tests/test_database_simple.py b/tests/test_database_simple.py deleted file mode 100644 index bb890a14..00000000 --- a/tests/test_database_simple.py +++ /dev/null @@ -1,277 +0,0 @@ -"""Test database model tables for simple set and get operations without complex relationships.""" - -import sys -from datetime import datetime -from pathlib import Path - -import pytest - -# Add src to path to allow imports -sys.path.insert(0, str(Path(__file__).parent.parent / "src")) - - -# Define functions at module level to make them pickleable (not test_ prefix to avoid pytest) -def simulation_function(x): - return x * 2 - - -def policy_modifier_function(sim): - sim.set_parameter("tax_rate", 0.25) - return sim - - -def dynamic_modifier_function(sim): - sim.set_parameter("benefit_amount", 1000) - return sim - - -@pytest.fixture -def fresh_database(): - """Create a fresh database instance for each test.""" - # Import here to avoid circular imports - from policyengine.database import Database - - # Use in-memory SQLite for testing - db = Database(url="sqlite:///:memory:") - db.create_tables() - return db - - -def test_model_table_set_and_get(fresh_database): - """Test ModelTable set and get operations.""" - from policyengine.models import Model - - model = Model( - id="test_model_1", - name="Test model", - description="A test model", - simulation_function=simulation_function, - ) - - # Set the model - fresh_database.set(model) - - # Get the model - retrieved_model = fresh_database.get(Model, id="test_model_1") - - assert retrieved_model is not None - assert retrieved_model.id == "test_model_1" - assert retrieved_model.name == "Test model" - assert retrieved_model.description == "A test model" - assert retrieved_model.simulation_function(5) == 10 - - -def test_dataset_table_set_and_get(fresh_database): - """Test DatasetTable set and get operations.""" - from policyengine.models import Dataset - - test_data = {"households": [{"id": 1, "income": 50000}]} - - dataset = Dataset( - id="test_dataset_1", - name="Test dataset", - data=test_data, - ) - - fresh_database.set(dataset) - retrieved = fresh_database.get(Dataset, id="test_dataset_1") - - assert retrieved is not None - assert retrieved.id == "test_dataset_1" - assert retrieved.name == "Test dataset" - assert retrieved.data == test_data - - -def test_versioned_dataset_table_set_and_get(fresh_database): - """Test VersionedDatasetTable set and get operations.""" - from policyengine.models import VersionedDataset - - versioned_dataset = VersionedDataset( - id="test_versioned_1", - name="Test versioned dataset", - description="A test versioned dataset", - ) - - 
fresh_database.set(versioned_dataset) - retrieved = fresh_database.get(VersionedDataset, id="test_versioned_1") - - assert retrieved is not None - assert retrieved.id == "test_versioned_1" - assert retrieved.name == "Test versioned dataset" - assert retrieved.description == "A test versioned dataset" - - -def test_policy_table_set_and_get(fresh_database): - """Test PolicyTable set and get operations.""" - from policyengine.models import Policy - - policy = Policy( - id="test_policy_1", - name="Test policy", - description="A test policy", - simulation_modifier=policy_modifier_function, - created_at=datetime.now(), - ) - - fresh_database.set(policy) - retrieved = fresh_database.get(Policy, id="test_policy_1") - - assert retrieved is not None - assert retrieved.id == "test_policy_1" - assert retrieved.name == "Test policy" - assert retrieved.description == "A test policy" - assert callable(retrieved.simulation_modifier) - - -def test_dynamic_table_set_and_get(fresh_database): - """Test DynamicTable set and get operations.""" - from policyengine.models import Dynamic - - dynamic = Dynamic( - id="test_dynamic_1", - name="Test dynamic", - description="A test dynamic policy", - simulation_modifier=dynamic_modifier_function, - created_at=datetime.now(), - ) - - fresh_database.set(dynamic) - retrieved = fresh_database.get(Dynamic, id="test_dynamic_1") - - assert retrieved is not None - assert retrieved.id == "test_dynamic_1" - assert retrieved.name == "Test dynamic" - assert retrieved.description == "A test dynamic policy" - assert callable(retrieved.simulation_modifier) - - -def test_user_table_set_and_get(fresh_database): - """Test UserTable set and get operations.""" - from policyengine.models import User - - user = User( - id="test_user_1", - username="testuser", - email="test@example.com", - first_name="Test", - last_name="User", - created_at=datetime.now(), - ) - - fresh_database.set(user) - retrieved = fresh_database.get(User, id="test_user_1") - - assert retrieved is not None - assert retrieved.id == "test_user_1" - assert retrieved.username == "testuser" - assert retrieved.email == "test@example.com" - assert retrieved.first_name == "Test" - assert retrieved.last_name == "User" - - -def test_report_table_set_and_get(fresh_database): - """Test ReportTable set and get operations.""" - from policyengine.models import Report - - report = Report( - id="test_report_1", - label="Test Report", - created_at=datetime.now(), - ) - - fresh_database.set(report) - retrieved = fresh_database.get(Report, id="test_report_1") - - assert retrieved is not None - assert retrieved.id == "test_report_1" - assert retrieved.label == "Test Report" - - -def test_report_element_table_set_and_get(fresh_database): - """Test ReportElementTable set and get operations.""" - from policyengine.models import ReportElement - - element = ReportElement( - id="test_element_1", - label="Test Element", - type="chart", - chart_type="bar", - ) - - fresh_database.set(element) - retrieved = fresh_database.get(ReportElement, id="test_element_1") - - assert retrieved is not None - assert retrieved.id == "test_element_1" - assert retrieved.label == "Test Element" - assert retrieved.type == "chart" - assert retrieved.chart_type == "bar" - - -def test_multiple_operations_on_same_table(fresh_database): - """Test multiple set and get operations on the same table.""" - from policyengine.models import Model - - # Create multiple model instances - models = [] - for i in range(3): - model = Model( - id=f"model_{i}", - name=f"Model {i}", - 
description=f"Model number {i}", - simulation_function=simulation_function, - ) - models.append(model) - fresh_database.set(model) - - # Retrieve all models - for i, original in enumerate(models): - retrieved = fresh_database.get(Model, id=f"model_{i}") - assert retrieved is not None - assert retrieved.id == original.id - assert retrieved.name == original.name - assert retrieved.description == original.description - - -def test_get_nonexistent_record(fresh_database): - """Test getting a record that doesn't exist.""" - from policyengine.models import Model - - result = fresh_database.get(Model, id="nonexistent_id") - assert result is None - - -def test_complex_data_compression(fresh_database): - """Test that complex data types are properly compressed and decompressed.""" - from policyengine.models import Dataset - - # Create a dataset with complex nested structure - complex_data = { - "households": [ - { - "id": i, - "income": 30000 + i * 5000, - "members": list(range(i + 1)), - } - for i in range(100) - ], - "metadata": { - "source": "test", - "year": 2024, - "nested": {"deep": {"structure": True}}, - }, - } - - dataset = Dataset( - id="complex_dataset", - name="Complex dataset", - data=complex_data, - ) - - fresh_database.set(dataset) - retrieved = fresh_database.get(Dataset, id="complex_dataset") - - assert retrieved is not None - assert retrieved.data == complex_data - assert retrieved.data["households"][50]["income"] == 280000 - assert retrieved.data["metadata"]["nested"]["deep"]["structure"] is True From 6f6e31b84d6e0add1c00c3128a6baa84bae79c36 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Sat, 8 Nov 2025 11:52:04 +0000 Subject: [PATCH 14/35] Folder rename --- db_init.py | 15 - docs/models.md | 47 - docs/myst.yml | 1 - docs/quickstart.ipynb | 2261 ----------------- src/policyengine/{models => core}/__init__.py | 0 src/policyengine/{models => core}/dataset.py | 0 .../{models => core}/dataset_version.py | 0 src/policyengine/{models => core}/dynamic.py | 0 .../{models => core}/parameter.py | 0 .../{models => core}/parameter_value.py | 0 src/policyengine/{models => core}/policy.py | 0 .../{models => core}/simulation.py | 0 .../{models => core}/tax_benefit_model.py | 0 .../tax_benefit_model_version.py | 0 src/policyengine/{models => core}/variable.py | 0 15 files changed, 2324 deletions(-) delete mode 100644 db_init.py delete mode 100644 docs/models.md delete mode 100644 docs/quickstart.ipynb rename src/policyengine/{models => core}/__init__.py (100%) rename src/policyengine/{models => core}/dataset.py (100%) rename src/policyengine/{models => core}/dataset_version.py (100%) rename src/policyengine/{models => core}/dynamic.py (100%) rename src/policyengine/{models => core}/parameter.py (100%) rename src/policyengine/{models => core}/parameter_value.py (100%) rename src/policyengine/{models => core}/policy.py (100%) rename src/policyengine/{models => core}/simulation.py (100%) rename src/policyengine/{models => core}/tax_benefit_model.py (100%) rename src/policyengine/{models => core}/tax_benefit_model_version.py (100%) rename src/policyengine/{models => core}/variable.py (100%) diff --git a/db_init.py b/db_init.py deleted file mode 100644 index f6a3e672..00000000 --- a/db_init.py +++ /dev/null @@ -1,15 +0,0 @@ -from policyengine.database import Database -from policyengine.models.policyengine_uk import policyengine_uk_latest_version -from policyengine.utils.datasets import create_uk_dataset - -# Load the dataset - -uk_dataset = create_uk_dataset() - -database = 
Database("postgresql://postgres:postgres@127.0.0.1:54322/postgres") - -# These two lines are not usually needed, but you should use them the first time you set up a new database -database.reset() # Drop and recreate all tables -database.register_model_version( - policyengine_uk_latest_version -) # Add in the model, model version, parameters and baseline parameter values and variables. diff --git a/docs/models.md b/docs/models.md deleted file mode 100644 index 010037a3..00000000 --- a/docs/models.md +++ /dev/null @@ -1,47 +0,0 @@ -# PolicyEngine model types guide - -This repository contains several model types that work together to enable policy simulation and analysis. Here's what each does: - -## Core simulation models - -**Model** - The main computational engine that registers tax-benefit systems (UK/US) and provides the simulation function. Contains logic to create seed objects from tax-benefit parameters. - -**Simulation** - Orchestrates policy analysis by combining a model, dataset, policy changes, and dynamic effects. Runs the model's simulation function and stores results. - -**ModelVersion** - Tracks different versions of a model implementation, allowing for comparison across model iterations. - -## Policy configuration - -**Policy** - Defines policy reforms through parameter value changes. Can include a custom simulation modifier function for complex reforms. - -**Dynamic** - Similar to Policy but specifically for dynamic/behavioural responses to policy changes. - -**Parameter** - Represents a single policy parameter (e.g., tax rate, benefit amount) within a model. - -**ParameterValue** - A specific value for a parameter at a given time period. - -**BaselineParameterValue** - Default/baseline values for parameters before any policy changes. - -**BaselineVariable** - Variables in the baseline scenario used for comparison. - -## Data handling - -**Dataset** - Contains the input data (households, people, etc.) for a simulation, with optional versioning and year specification. - -**VersionedDataset** - Manages different versions of datasets over time. - -## Results and reporting - -**Report** - Container for analysis results with timestamp tracking. - -**ReportElement** - Individual components within a report (charts, tables, metrics). - -**Aggregate** - Computes aggregated statistics (sum, mean, count) from simulation results, with optional filtering. - -**AggregateType** - Enum defining the available aggregation functions. - -## Supporting models - -**User** - User account management for the platform. - -**SeedObjects** - Helper class for batch-creating initial database objects when registering a new model. \ No newline at end of file diff --git a/docs/myst.yml b/docs/myst.yml index c8ccc8d5..ea5176cf 100644 --- a/docs/myst.yml +++ b/docs/myst.yml @@ -12,7 +12,6 @@ project: # Auto-generated by `myst init --write-toc` - file: index.md - file: quickstart.ipynb - - file: models.md - file: dev.md site: diff --git a/docs/quickstart.ipynb b/docs/quickstart.ipynb deleted file mode 100644 index 8c497bff..00000000 --- a/docs/quickstart.ipynb +++ /dev/null @@ -1,2261 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "b5510438", - "metadata": {}, - "source": [ - "# Getting started\n", - "\n", - "In this notebook, we'll walk through how to use the PolicyEngine.py package to run simulations and produce analyses. 
We'll start with a basic analysis in the UK that doesn't use any databases, and then start saving and loading things into a database.\n",
-    "\n",
-    "## Running a baseline simulation\n",
-    "\n",
-    "To start, let's run through a simulation of the UK, and create a chart of the distribution of household income."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "id": "7eb9b5a0",
-   "metadata": {},
-   "outputs": [
-    [cell outputs omitted: an HTML script stub and the plotly figure JSON for a bar chart titled "The distribution of household income in the UK", x-axis "Income range", y-axis "Number of households"]
-   ],
-   "source": [
-    "import plotly.graph_objects as go\n",
-    "\n",
-    "from policyengine.models import (\n",
-    "    Aggregate,\n",
-    "    Simulation,\n",
-    "    policyengine_uk_latest_version,\n",
-    "    policyengine_uk_model,\n",
-    ")\n",
-    "from policyengine.utils.charts import add_fonts, format_figure\n",
-    "from policyengine.utils.datasets import create_uk_dataset\n",
-    "\n",
-    "# Load the dataset\n",
-    "\n",
-    "uk_dataset = create_uk_dataset()\n",
-    "\n",
-    "# Create and run the simulation\n",
-    "\n",
-    "sim = Simulation(\n",
-    "    dataset=uk_dataset,\n",
-    "    model=policyengine_uk_model,\n",
-    "    model_version=policyengine_uk_latest_version,\n",
-    ")\n",
-    "\n",
-    "sim.run()\n",
-    "\n",
-    "# Extract aggregates for household income ranges\n",
-    "\n",
-    "income_ranges = [\n",
-    "    0,\n",
-    "    20000,\n",
-    "    40000,\n",
-    "    60000,\n",
-    "    80000,\n",
-    "    100000,\n",
-    "    150000,\n",
-    "    200000,\n",
-    "    300000,\n",
-    "    500000,\n",
-    "    1_000_000,\n",
-    "]\n",
-    "aggregates = []\n",
-    "for i in range(len(income_ranges) - 1):\n",
-    "    aggregates.append(\n",
-    "        Aggregate(\n",
-    "            entity=\"household\",\n",
-    "            variable_name=\"hbai_household_net_income\",\n",
-    "            aggregate_function=\"count\",\n",
-    "            filter_variable_name=\"hbai_household_net_income\",\n",
-    "            filter_variable_geq=income_ranges[i],\n",
-    "            filter_variable_leq=income_ranges[i + 1],\n",
-    "            simulation=sim,\n",
-    "        )\n",
-    "    )\n",
-    "\n",
-    "aggregates = Aggregate.run(aggregates)\n",
-    "\n",
-    "# Create the bar chart\n",
-    "\n",
-    "fig = go.Figure(\n",
-    "    data=[\n",
-    "        go.Bar(\n",
-    "            x=[f\"£{inc:,}\" for inc in income_ranges[:-1]],\n",
-    "            y=[agg.value for agg in aggregates],\n",
-    "        )\n",
-    "    ]\n",
-    ")\n",
-    "\n",
-    "# Apply formatting\n",
-    "\n",
-    "format_figure(\n",
-    "    fig,\n",
-    "    title=\"The distribution of household income in the UK\",\n",
-    "    x_title=\"Income range\",\n",
-    "    y_title=\"Number of households\",\n",
-    ")"
-   ]
-  },
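A quick aside on the bracket loop in the cell above: both filter bounds are inclusive, `filter_variable_geq` and `filter_variable_leq`, so a household sitting exactly on an edge such as £20,000 satisfies two adjacent brackets. When building the same histogram directly, half-open bins avoid the double count; a minimal numpy sketch with invented incomes (the same caveat applies to the reform cell further below, which reuses these filters):

import numpy as np

incomes = np.array([15_000, 20_000, 35_000])  # hypothetical household incomes
edges = [0, 20_000, 40_000]

# np.histogram uses half-open bins [a, b), so 20,000 lands in exactly one bin
counts, _ = np.histogram(incomes, bins=edges)
print(counts)  # [1 2]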
-  {
-   "cell_type": "markdown",
-   "id": "24ba497b",
-   "metadata": {},
-   "source": [
-    "So, in this example we introduced a few concepts:\n",
-    "\n",
-    "* The `Simulation` object, which represents a full run of a microsimulation model, containing all the information (simulated and input) about a set of people or groups. Here it takes a few arguments: a `Dataset`, a `Model` and a `ModelVersion`.\n",
-    "* The `Dataset` object, which represents a set of people or groups. Here we used a utility function to create this dataset for the UK, but we'll later be able to create them from scratch or pull them from a database.\n",
-    "* The `Model` object, which represents a particular microsimulation model (essentially defined as a function transforming a dataset into a new dataset). There are two models defined by this package, one for the UK and one for the US. Think of these objects as adapters representing the full microsimulation models. Here, we've taken the pre-defined UK model.\n",
-    "* The `ModelVersion` object, which represents a particular version of a model. This is useful for tracking changes to the model over time. Here, we used the latest version of the UK model.\n",
-    "\n",
-    "## Adding a policy reform\n",
-    "\n",
-    "Next, we'll add in a policy reform, and see how that changes the results."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 2,
-   "id": "b40913b2",
-   "metadata": {},
-   "outputs": [
-    [cell output omitted: plotly figure JSON for a grouped bar chart titled "The distribution of household income in the UK" with "Baseline" and "Reform" series, x-axis "Income range", y-axis "Number of households"]
-   ],
-   "source": [
-    "from datetime import datetime\n",
-    "\n",
-    "from policyengine.models import Parameter, ParameterValue, Policy\n",
-    "\n",
-    "# Parameter = the parameter to change\n",
-    "\n",
-    "personal_allowance = Parameter(\n",
-    "    id=\"gov.hmrc.income_tax.allowances.personal_allowance.amount\",\n",
-    "    model=policyengine_uk_model,\n",
-    ")\n",
-    "\n",
- "# ParameterValue = the value to set the parameter to, and when to start\n", - "\n", - "personal_allowance_value = ParameterValue(\n", - " parameter=personal_allowance,\n", - " start_date=datetime(2029, 1, 1),\n", - " value=20000,\n", - ")\n", - "\n", - "# Create a policy to increase the personal allowance to £20,000 from 2029-30\n", - "\n", - "policy = Policy(\n", - " name=\"Increase personal allowance to £20,000\",\n", - " description=\"A policy to increase the personal allowance for income tax to £20,000.\",\n", - " parameter_values=[personal_allowance_value],\n", - ")\n", - "\n", - "sim_2 = Simulation(\n", - " dataset=uk_dataset,\n", - " model=policyengine_uk_model,\n", - " model_version=policyengine_uk_latest_version,\n", - " policy=policy, # Pass in the policy here\n", - ")\n", - "\n", - "sim_2.run()\n", - "\n", - "# Extract new aggregates for household income ranges\n", - "\n", - "income_ranges = [\n", - " 0,\n", - " 20000,\n", - " 40000,\n", - " 60000,\n", - " 80000,\n", - " 100000,\n", - " 150000,\n", - " 200000,\n", - " 300000,\n", - " 500000,\n", - " 1_000_000,\n", - "]\n", - "aggregates_2 = []\n", - "for i in range(len(income_ranges) - 1):\n", - " aggregates_2.append(\n", - " Aggregate(\n", - " entity=\"household\",\n", - " variable_name=\"hbai_household_net_income\",\n", - " aggregate_function=\"count\",\n", - " filter_variable_name=\"hbai_household_net_income\",\n", - " filter_variable_geq=income_ranges[i],\n", - " filter_variable_leq=income_ranges[i + 1],\n", - " simulation=sim_2,\n", - " )\n", - " )\n", - "\n", - "aggregates_2 = Aggregate.run(aggregates_2)\n", - "\n", - "# Create the comparative bar chart\n", - "fig = go.Figure(\n", - " data=[\n", - " go.Bar(\n", - " name=\"Baseline\",\n", - " x=[f\"£{inc:,}\" for inc in income_ranges[:-1]],\n", - " y=[agg.value for agg in aggregates],\n", - " ),\n", - " go.Bar(\n", - " name=\"Reform\",\n", - " x=[f\"£{inc:,}\" for inc in income_ranges[:-1]],\n", - " y=[agg.value for agg in aggregates_2],\n", - " ),\n", - " ]\n", - ")\n", - "\n", - "# Apply formatting\n", - "fig = format_figure(\n", - " fig,\n", - " title=\"The distribution of household income in the UK\",\n", - " x_title=\"Income range\",\n", - " y_title=\"Number of households\",\n", - ")\n", - "\n", - "add_fonts()\n", - "\n", - "fig" - ] - }, - { - "cell_type": "markdown", - "id": "6c029d3b", - "metadata": {}, - "source": [ - "In the above example, we created a `Policy` object, which represents a particular policy reform. This object contains a list of `ParameterValue` objects, which represent changes to specific parameters in the model. Here, we changed the personal allowance for income tax to £20,000.\n", - "\n", - "## Bringing in a database\n", - "\n", - "Now, we can upload these objects to a database, and then load them back out again. This is useful for tracking different simulations and policy reforms over time." 
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 3,
-   "id": "ac6c443e",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "from policyengine.database import Database\n",
-    "from policyengine.models.policyengine_uk import policyengine_uk_latest_version\n",
-    "\n",
-    "database = Database(\"postgresql://postgres:postgres@127.0.0.1:54322/postgres\")\n",
-    "\n",
-    "# These two lines are not usually needed, but you should use them the first time you set up a new database\n",
-    "database.reset() # Drop and recreate all tables\n",
-    "database.register_model_version(\n",
-    "    policyengine_uk_latest_version\n",
-    ") # Add in the model, model version, parameters and baseline parameter values and variables.\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 4,
-   "id": "f14c85eb",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "from policyengine.database import Database\n",
-    "\n",
-    "database = Database(\"postgresql://postgres:postgres@127.0.0.1:54322/postgres\")\n",
-    "\n",
-    "# These two lines are not usually needed, but you should use them the first time you set up a new database\n",
-    "database.reset() # Drop and recreate all tables\n",
-    "database.register_model_version(\n",
-    "    policyengine_uk_latest_version\n",
-    ") # Add in the model, model version, parameters and baseline parameter values and variables.\n",
-    "\n",
-    "database.set(uk_dataset)\n",
-    "database.set(policy)\n",
-    "\n",
-    "for pv in policy.parameter_values:\n",
-    "    database.set(pv)\n",
-    "database.set(sim)\n",
-    "database.set(sim_2)\n",
-    "for agg in aggregates:\n",
-    "    database.set(agg)\n",
-    "for agg in aggregates_2:\n",
-    "    database.set(agg)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 5,
-   "id": "2041dfeb",
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "Policy(id='77545b6d-1294-4f7d-8646-cbc6d9d4b054', name='Increase personal allowance to £20,000', description='A policy to increase the personal allowance for income tax to £20,000.', parameter_values=[ParameterValue(id='0416cef8-f6f1-4925-bba6-60c2527838ae', parameter=Parameter(id='gov.hmrc.income_tax.allowances.personal_allowance.amount', description=None, data_type=None, model=Model(id='policyengine_uk', name='PolicyEngine UK', description=\"PolicyEngine's open-source tax-benefit microsimulation model.\", simulation_function=<function ...>), label=None, unit=None), value=20000, start_date=datetime.datetime(2029, 1, 1, 0, 0), end_date=None)], simulation_modifier=None, created_at=datetime.datetime(2025, 10, 3, 17, 46, 47, 141804), updated_at=datetime.datetime(2025, 10, 3, 17, 46, 47, 141809))"
-      ]
-     },
-     "execution_count": 5,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "database.get(Policy, id=policy.id)"
-   ]
-  }
- ],
- "metadata": {
-  "kernelspec": {
-   "display_name": "policyengine",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.13.5"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
diff --git a/src/policyengine/models/__init__.py b/src/policyengine/core/__init__.py
similarity index 100%
rename from src/policyengine/models/__init__.py
rename to src/policyengine/core/__init__.py
diff --git a/src/policyengine/models/dataset.py b/src/policyengine/core/dataset.py
similarity index 100%
rename from src/policyengine/models/dataset.py
rename to
src/policyengine/core/dataset.py diff --git a/src/policyengine/models/dataset_version.py b/src/policyengine/core/dataset_version.py similarity index 100% rename from src/policyengine/models/dataset_version.py rename to src/policyengine/core/dataset_version.py diff --git a/src/policyengine/models/dynamic.py b/src/policyengine/core/dynamic.py similarity index 100% rename from src/policyengine/models/dynamic.py rename to src/policyengine/core/dynamic.py diff --git a/src/policyengine/models/parameter.py b/src/policyengine/core/parameter.py similarity index 100% rename from src/policyengine/models/parameter.py rename to src/policyengine/core/parameter.py diff --git a/src/policyengine/models/parameter_value.py b/src/policyengine/core/parameter_value.py similarity index 100% rename from src/policyengine/models/parameter_value.py rename to src/policyengine/core/parameter_value.py diff --git a/src/policyengine/models/policy.py b/src/policyengine/core/policy.py similarity index 100% rename from src/policyengine/models/policy.py rename to src/policyengine/core/policy.py diff --git a/src/policyengine/models/simulation.py b/src/policyengine/core/simulation.py similarity index 100% rename from src/policyengine/models/simulation.py rename to src/policyengine/core/simulation.py diff --git a/src/policyengine/models/tax_benefit_model.py b/src/policyengine/core/tax_benefit_model.py similarity index 100% rename from src/policyengine/models/tax_benefit_model.py rename to src/policyengine/core/tax_benefit_model.py diff --git a/src/policyengine/models/tax_benefit_model_version.py b/src/policyengine/core/tax_benefit_model_version.py similarity index 100% rename from src/policyengine/models/tax_benefit_model_version.py rename to src/policyengine/core/tax_benefit_model_version.py diff --git a/src/policyengine/models/variable.py b/src/policyengine/core/variable.py similarity index 100% rename from src/policyengine/models/variable.py rename to src/policyengine/core/variable.py From 74915a06d4f7a7a94ba49f5419329bfe11afc094 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Sat, 8 Nov 2025 13:47:07 +0000 Subject: [PATCH 15/35] Add dataset --- src/policyengine/core/__init__.py | 33 +---- src/policyengine/core/dataset.py | 1 + src/policyengine/core/dataset_version.py | 2 +- src/policyengine/core/tax_benefit_model.py | 7 + src/policyengine/tax_benefit_models/uk.py | 54 +++++++ tests/test_uk_dataset.py | 165 +++++++++++++++++++++ 6 files changed, 232 insertions(+), 30 deletions(-) create mode 100644 src/policyengine/tax_benefit_models/uk.py create mode 100644 tests/test_uk_dataset.py diff --git a/src/policyengine/core/__init__.py b/src/policyengine/core/__init__.py index 52f87c40..a958a071 100644 --- a/src/policyengine/core/__init__.py +++ b/src/policyengine/core/__init__.py @@ -1,36 +1,11 @@ -from .aggregate import Aggregate as Aggregate -from .aggregate import AggregateType as AggregateType -from .aggregate_change import AggregateChange as AggregateChange -from .baseline_parameter_value import ( - BaselineParameterValue as BaselineParameterValue, -) -from .variable import BaselineVariable as BaselineVariable +from .variable import Variable from .dataset import Dataset as Dataset from .dynamic import Dynamic as Dynamic -from .tax_benefit_model import Model as Model -from .tax_benefit_model_version import ModelVersion as ModelVersion +from .tax_benefit_model import TaxBenefitModel +from .tax_benefit_model_version import TaxBenefitModelVersion from .parameter import Parameter as Parameter from .parameter_value import 
ParameterValue as ParameterValue from .policy import Policy as Policy -from .policyengine_uk import ( - policyengine_uk_latest_version as policyengine_uk_latest_version, -) -from .policyengine_uk import ( - policyengine_uk_model as policyengine_uk_model, -) -from .policyengine_us import ( - policyengine_us_latest_version as policyengine_us_latest_version, -) -from .policyengine_us import ( - policyengine_us_model as policyengine_us_model, -) from .simulation import Simulation as Simulation -from .dataset_version import VersionedDataset as VersionedDataset - -# Rebuild models to handle circular references -from .aggregate import Aggregate -from .aggregate_change import AggregateChange +from .dataset_version import DatasetVersion from .simulation import Simulation -Aggregate.model_rebuild() -AggregateChange.model_rebuild() -Simulation.model_rebuild() diff --git a/src/policyengine/core/dataset.py b/src/policyengine/core/dataset.py index 4fc899b5..a2abf401 100644 --- a/src/policyengine/core/dataset.py +++ b/src/policyengine/core/dataset.py @@ -7,6 +7,7 @@ from .dataset_version import DatasetVersion + class Dataset(BaseModel): id: str = Field(default_factory=lambda: str(uuid4())) name: str diff --git a/src/policyengine/core/dataset_version.py b/src/policyengine/core/dataset_version.py index 7594c812..c6837818 100644 --- a/src/policyengine/core/dataset_version.py +++ b/src/policyengine/core/dataset_version.py @@ -10,6 +10,6 @@ class DatasetVersion(BaseModel): id: str = Field(default_factory=lambda: str(uuid4())) - dataset: Dataset + dataset: "Dataset" description: str tax_benefit_model: TaxBenefitModel = None diff --git a/src/policyengine/core/tax_benefit_model.py b/src/policyengine/core/tax_benefit_model.py index 5ad2ac84..c8c8a6fb 100644 --- a/src/policyengine/core/tax_benefit_model.py +++ b/src/policyengine/core/tax_benefit_model.py @@ -7,9 +7,16 @@ if TYPE_CHECKING: from .variable import Variable from .parameter import Parameter + from .simulation import Simulation + from .dataset import Dataset + from .policy import Policy + from .dynamic import Dynamic class TaxBenefitModel(BaseModel): id: str name: str description: str | None = None + + def run(self, simulation: "Simulation") -> "Simulation": + pass \ No newline at end of file diff --git a/src/policyengine/tax_benefit_models/uk.py b/src/policyengine/tax_benefit_models/uk.py new file mode 100644 index 00000000..ff29bf9e --- /dev/null +++ b/src/policyengine/tax_benefit_models/uk.py @@ -0,0 +1,54 @@ +from policyengine.core import * +from pydantic import BaseModel, Field, ConfigDict +import pandas as pd +from typing import Dict + +class YearData(BaseModel): + """Entity-level data for a single year.""" + model_config = ConfigDict(arbitrary_types_allowed=True) + + person: pd.DataFrame + benunit: pd.DataFrame + household: pd.DataFrame + +class PolicyEngineUKDataset(Dataset): + """UK dataset with multi-year entity-level data.""" + data: Dict[int, YearData] = Field(default_factory=dict) + + def save(self, filepath: str) -> None: + """Save dataset to HDF5 file.""" + with pd.HDFStore(filepath, mode='w') as store: + for year, year_data in self.data.items(): + store[f'{year}/person'] = year_data.person + store[f'{year}/benunit'] = year_data.benunit + store[f'{year}/household'] = year_data.household + + @classmethod + def load(cls, filepath: str, **kwargs) -> 'PolicyEngineUKDataset': + """Load dataset from HDF5 file.""" + data = {} + with pd.HDFStore(filepath, mode='r') as store: + # Get all years from keys + years = set() + for key in store.keys(): + 
# Keys are like '/2025/person' + year = int(key.split('/')[1]) + years.add(year) + + # Load data for each year + for year in years: + data[year] = YearData( + person=store[f'{year}/person'], + benunit=store[f'{year}/benunit'], + household=store[f'{year}/household'] + ) + + # Create instance with data + return cls(data=data, filepath=filepath, **kwargs) + +class PolicyEngineUK(TaxBenefitModel): + pass + + +# Rebuild models to resolve forward references +PolicyEngineUKDataset.model_rebuild() diff --git a/tests/test_uk_dataset.py b/tests/test_uk_dataset.py new file mode 100644 index 00000000..64faa343 --- /dev/null +++ b/tests/test_uk_dataset.py @@ -0,0 +1,165 @@ +import pandas as pd +import tempfile +import os +from policyengine.tax_benefit_models.uk import PolicyEngineUKDataset, YearData + + +def test_save_and_load_single_year(): + """Test saving and loading a dataset with a single year.""" + # Create sample data + person_df = pd.DataFrame({ + 'person_id': [1, 2, 3], + 'age': [25, 30, 35], + 'income': [30000, 45000, 60000] + }) + + benunit_df = pd.DataFrame({ + 'benunit_id': [1, 2], + 'size': [2, 1], + 'total_income': [75000, 60000] + }) + + household_df = pd.DataFrame({ + 'household_id': [1], + 'num_people': [3], + 'rent': [1200] + }) + + # Create dataset + dataset = PolicyEngineUKDataset( + name='Test Dataset', + description='A test dataset', + filepath='test.h5', + data={ + 2025: YearData( + person=person_df, + benunit=benunit_df, + household=household_df + ) + } + ) + + # Save to temporary file + with tempfile.TemporaryDirectory() as tmpdir: + filepath = os.path.join(tmpdir, 'test_dataset.h5') + dataset.save(filepath) + + # Load it back + loaded = PolicyEngineUKDataset.load( + filepath, + name='Loaded Dataset', + description='Loaded from file' + ) + + # Verify data + assert 2025 in loaded.data + pd.testing.assert_frame_equal(loaded.data[2025].person, person_df) + pd.testing.assert_frame_equal(loaded.data[2025].benunit, benunit_df) + pd.testing.assert_frame_equal(loaded.data[2025].household, household_df) + + +def test_save_and_load_multiple_years(): + """Test saving and loading a dataset with multiple years.""" + # Create sample data for 2025 + person_2025 = pd.DataFrame({ + 'person_id': [1, 2], + 'age': [25, 30], + 'income': [30000, 45000] + }) + + benunit_2025 = pd.DataFrame({ + 'benunit_id': [1], + 'size': [2], + 'total_income': [75000] + }) + + household_2025 = pd.DataFrame({ + 'household_id': [1], + 'num_people': [2], + 'rent': [1200] + }) + + # Create sample data for 2026 + person_2026 = pd.DataFrame({ + 'person_id': [1, 2, 3], + 'age': [26, 31, 22], + 'income': [32000, 47000, 28000] + }) + + benunit_2026 = pd.DataFrame({ + 'benunit_id': [1, 2], + 'size': [2, 1], + 'total_income': [79000, 28000] + }) + + household_2026 = pd.DataFrame({ + 'household_id': [1], + 'num_people': [3], + 'rent': [1300] + }) + + # Create dataset with multiple years + dataset = PolicyEngineUKDataset( + name='Multi-year Dataset', + description='Dataset with multiple years', + filepath='test.h5', + data={ + 2025: YearData( + person=person_2025, + benunit=benunit_2025, + household=household_2025 + ), + 2026: YearData( + person=person_2026, + benunit=benunit_2026, + household=household_2026 + ) + } + ) + + # Save and load + with tempfile.TemporaryDirectory() as tmpdir: + filepath = os.path.join(tmpdir, 'multi_year_dataset.h5') + dataset.save(filepath) + + loaded = PolicyEngineUKDataset.load( + filepath, + name='Loaded Multi-year', + description='Loaded from file' + ) + + # Verify both years exist + assert 2025 
in loaded.data + assert 2026 in loaded.data + + # Verify 2025 data + pd.testing.assert_frame_equal(loaded.data[2025].person, person_2025) + pd.testing.assert_frame_equal(loaded.data[2025].benunit, benunit_2025) + pd.testing.assert_frame_equal(loaded.data[2025].household, household_2025) + + # Verify 2026 data + pd.testing.assert_frame_equal(loaded.data[2026].person, person_2026) + pd.testing.assert_frame_equal(loaded.data[2026].benunit, benunit_2026) + pd.testing.assert_frame_equal(loaded.data[2026].household, household_2026) + + +def test_empty_dataset(): + """Test creating and saving an empty dataset.""" + dataset = PolicyEngineUKDataset( + name='Empty Dataset', + description='No data yet', + filepath='empty.h5', + data={} + ) + + with tempfile.TemporaryDirectory() as tmpdir: + filepath = os.path.join(tmpdir, 'empty_dataset.h5') + dataset.save(filepath) + + loaded = PolicyEngineUKDataset.load( + filepath, + name='Loaded Empty', + description='Empty dataset loaded' + ) + + assert len(loaded.data) == 0 From b4b12cfcf966ebd560b61cba3d368414a5e24c43 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Wed, 12 Nov 2025 11:48:50 +0000 Subject: [PATCH 16/35] Update --- src/policyengine/core/dataset.py | 1 + src/policyengine/core/simulation.py | 5 +- src/policyengine/core/tax_benefit_model.py | 3 +- src/policyengine/tax_benefit_models/uk.py | 59 +++++++++++++++++++--- tests/test_uk_dataset.py | 21 ++++---- 5 files changed, 70 insertions(+), 19 deletions(-) diff --git a/src/policyengine/core/dataset.py b/src/policyengine/core/dataset.py index a2abf401..a1f30172 100644 --- a/src/policyengine/core/dataset.py +++ b/src/policyengine/core/dataset.py @@ -14,4 +14,5 @@ class Dataset(BaseModel): description: str dataset_version: DatasetVersion | None = None filepath: str + is_output_dataset: bool = False tax_benefit_model: TaxBenefitModel = None diff --git a/src/policyengine/core/simulation.py b/src/policyengine/core/simulation.py index 490a0ad2..6e48944c 100644 --- a/src/policyengine/core/simulation.py +++ b/src/policyengine/core/simulation.py @@ -18,8 +18,9 @@ class Simulation(BaseModel): policy: Policy | None = None dynamic: Dynamic | None = None - dataset: Dataset | None = None + dataset: Dataset = None + year: int tax_benefit_model: TaxBenefitModel | None = None tax_benefit_model_version: TaxBenefitModelVersion | None = None - output_file_path: str | None = None + output_dataset: Dataset | None = None diff --git a/src/policyengine/core/tax_benefit_model.py b/src/policyengine/core/tax_benefit_model.py index c8c8a6fb..00fd9e41 100644 --- a/src/policyengine/core/tax_benefit_model.py +++ b/src/policyengine/core/tax_benefit_model.py @@ -15,8 +15,7 @@ class TaxBenefitModel(BaseModel): id: str - name: str description: str | None = None def run(self, simulation: "Simulation") -> "Simulation": - pass \ No newline at end of file + raise NotImplementedError("The TaxBenefitModel class must define a method to execute simulations.") \ No newline at end of file diff --git a/src/policyengine/tax_benefit_models/uk.py b/src/policyengine/tax_benefit_models/uk.py index ff29bf9e..3b896b49 100644 --- a/src/policyengine/tax_benefit_models/uk.py +++ b/src/policyengine/tax_benefit_models/uk.py @@ -23,9 +23,8 @@ def save(self, filepath: str) -> None: store[f'{year}/benunit'] = year_data.benunit store[f'{year}/household'] = year_data.household - @classmethod - def load(cls, filepath: str, **kwargs) -> 'PolicyEngineUKDataset': - """Load dataset from HDF5 file.""" + def load(self, filepath: str) -> None: + """Load dataset from 
HDF5 file into this instance.""" data = {} with pd.HDFStore(filepath, mode='r') as store: # Get all years from keys @@ -43,11 +42,59 @@ def load(cls, filepath: str, **kwargs) -> 'PolicyEngineUKDataset': household=store[f'{year}/household'] ) - # Create instance with data - return cls(data=data, filepath=filepath, **kwargs) + self.data = data + self.filepath = filepath class PolicyEngineUK(TaxBenefitModel): - pass + id: str = "policyengine-uk" + + def run(self, simulation: "Simulation") -> "Simulation": + from policyengine_uk import Microsimulation + from policyengine_uk.data import UKSingleYearDataset, UKMultiYearDataset + + assert isinstance(simulation.dataset, PolicyEngineUKDataset) + + dataset = simulation.dataset + dataset.load() + year_data = dataset.data[next(iter(dataset.data))] + input_data = UKSingleYearDataset( + person=year_data.person, + benunit=year_data.benunit, + household=year_data.household, + fiscal_year=next(iter(dataset.data)) + ) + microsim = Microsimulation(dataset=input_data) + + entity_variables = { + "person": ["person_id", "benunit_id", "household_id", "age", "employment_income", "person_weight"], + "benunit": ["benunit_id", "benunit_weight"], + "household": ["household_id", "household_weight", "hbai_household_net_income", "equiv_hbai_household_net_income"], + } + + data = { + "person": pd.DataFrame(), + "benunit": pd.DataFrame(), + "household": pd.DataFrame(), + } + + for entity, variables in entity_variables.items(): + for var in variables: + data[entity][var] = microsim.calculate(var, period=simulation.year, map_to=entity).values + + output_dataset = PolicyEngineUKDataset( + name=dataset.name, + description=dataset.description, + filepath=dataset.filepath, + data={ + simulation.year: YearData( + person=data["person"], + benunit=data["benunit"], + household=data["household"] + ) + } + ) + + output_dataset.save(dataset.filepath) # Rebuild models to resolve forward references diff --git a/tests/test_uk_dataset.py b/tests/test_uk_dataset.py index 64faa343..620de616 100644 --- a/tests/test_uk_dataset.py +++ b/tests/test_uk_dataset.py @@ -45,11 +45,12 @@ def test_save_and_load_single_year(): dataset.save(filepath) # Load it back - loaded = PolicyEngineUKDataset.load( - filepath, + loaded = PolicyEngineUKDataset( name='Loaded Dataset', - description='Loaded from file' + description='Loaded from file', + filepath=filepath ) + loaded.load(filepath) # Verify data assert 2025 in loaded.data @@ -122,11 +123,12 @@ def test_save_and_load_multiple_years(): filepath = os.path.join(tmpdir, 'multi_year_dataset.h5') dataset.save(filepath) - loaded = PolicyEngineUKDataset.load( - filepath, + loaded = PolicyEngineUKDataset( name='Loaded Multi-year', - description='Loaded from file' + description='Loaded from file', + filepath=filepath ) + loaded.load(filepath) # Verify both years exist assert 2025 in loaded.data @@ -156,10 +158,11 @@ def test_empty_dataset(): filepath = os.path.join(tmpdir, 'empty_dataset.h5') dataset.save(filepath) - loaded = PolicyEngineUKDataset.load( - filepath, + loaded = PolicyEngineUKDataset( name='Loaded Empty', - description='Empty dataset loaded' + description='Empty dataset loaded', + filepath=filepath ) + loaded.load(filepath) assert len(loaded.data) == 0 From 91d228aaaa9327f83e9328dae778ab56541ac34c Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Wed, 12 Nov 2025 14:07:20 +0000 Subject: [PATCH 17/35] Add progress --- CLAUDE.md | 7 + src/policyengine/core/__init__.py | 19 +- src/policyengine/core/dataset.py | 2 +- 
src/policyengine/core/dataset_version.py | 1 + src/policyengine/core/simulation.py | 7 +- src/policyengine/core/tax_benefit_model.py | 6 +- .../core/tax_benefit_model_version.py | 22 +- src/policyengine/core/variable.py | 1 - src/policyengine/tax_benefit_models/uk.py | 193 ++++++++++++---- src/policyengine/utils/__init__.py | 1 + src/policyengine/utils/dates.py | 24 ++ tests/test_uk_dataset.py | 208 +++++------------- 12 files changed, 279 insertions(+), 212 deletions(-) create mode 100644 src/policyengine/utils/__init__.py create mode 100644 src/policyengine/utils/dates.py diff --git a/CLAUDE.md b/CLAUDE.md index e69de29b..31c70c85 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -0,0 +1,7 @@ +# Claude notes + +Claude, please follow these always. These principles are aimed at preventing you from producing AI slop. + +1. British English, sentence case +2. No excessive duplication, keep code files as concise as possible to produce the same meaningful value. No excessive printing +3. Don't create multiple files for successive versions. Keep checking: have I added lots of intermediate files which are deprecated? Delete them if so, but ideally don't create them in the first place diff --git a/src/policyengine/core/__init__.py b/src/policyengine/core/__init__.py index a958a071..ef86cb16 100644 --- a/src/policyengine/core/__init__.py +++ b/src/policyengine/core/__init__.py @@ -1,11 +1,16 @@ from .variable import Variable -from .dataset import Dataset as Dataset -from .dynamic import Dynamic as Dynamic +from .dataset import Dataset +from .dynamic import Dynamic from .tax_benefit_model import TaxBenefitModel from .tax_benefit_model_version import TaxBenefitModelVersion -from .parameter import Parameter as Parameter -from .parameter_value import ParameterValue as ParameterValue -from .policy import Policy as Policy -from .simulation import Simulation as Simulation -from .dataset_version import DatasetVersion +from .parameter import Parameter +from .parameter_value import ParameterValue +from .policy import Policy from .simulation import Simulation +from .dataset_version import DatasetVersion + +# Rebuild models to resolve forward references +TaxBenefitModelVersion.model_rebuild() +Variable.model_rebuild() +Parameter.model_rebuild() +ParameterValue.model_rebuild() diff --git a/src/policyengine/core/dataset.py b/src/policyengine/core/dataset.py index a1f30172..7275d2df 100644 --- a/src/policyengine/core/dataset.py +++ b/src/policyengine/core/dataset.py @@ -7,7 +7,6 @@ from .dataset_version import DatasetVersion - class Dataset(BaseModel): id: str = Field(default_factory=lambda: str(uuid4())) name: str @@ -16,3 +15,4 @@ class Dataset(BaseModel): filepath: str is_output_dataset: bool = False tax_benefit_model: TaxBenefitModel = None + year: int diff --git a/src/policyengine/core/dataset_version.py b/src/policyengine/core/dataset_version.py index c6837818..29a0150b 100644 --- a/src/policyengine/core/dataset_version.py +++ b/src/policyengine/core/dataset_version.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING from .tax_benefit_model import TaxBenefitModel + if TYPE_CHECKING: from .dataset import Dataset diff --git a/src/policyengine/core/simulation.py b/src/policyengine/core/simulation.py index 6e48944c..8006f1de 100644 --- a/src/policyengine/core/simulation.py +++ b/src/policyengine/core/simulation.py @@ -19,8 +19,9 @@ class Simulation(BaseModel): policy: Policy | None = None dynamic: Dynamic | None = None dataset: Dataset = None - year: int - tax_benefit_model: TaxBenefitModel | None = None - 
tax_benefit_model_version: TaxBenefitModelVersion | None = None + tax_benefit_model_version: TaxBenefitModelVersion = None output_dataset: Dataset | None = None + + def run(self): + self.tax_benefit_model_version.run(self) diff --git a/src/policyengine/core/tax_benefit_model.py b/src/policyengine/core/tax_benefit_model.py index 00fd9e41..255a21b1 100644 --- a/src/policyengine/core/tax_benefit_model.py +++ b/src/policyengine/core/tax_benefit_model.py @@ -2,7 +2,7 @@ from datetime import datetime from typing import TYPE_CHECKING -from pydantic import BaseModel +from pydantic import BaseModel, Field if TYPE_CHECKING: from .variable import Variable @@ -11,11 +11,9 @@ from .dataset import Dataset from .policy import Policy from .dynamic import Dynamic + from .parameter_value import ParameterValue class TaxBenefitModel(BaseModel): id: str description: str | None = None - - def run(self, simulation: "Simulation") -> "Simulation": - raise NotImplementedError("The TaxBenefitModel class must define a method to execute simulations.") \ No newline at end of file diff --git a/src/policyengine/core/tax_benefit_model_version.py b/src/policyengine/core/tax_benefit_model_version.py index 1c702b06..0b6da7f3 100644 --- a/src/policyengine/core/tax_benefit_model_version.py +++ b/src/policyengine/core/tax_benefit_model_version.py @@ -4,6 +4,13 @@ from pydantic import BaseModel, Field from .tax_benefit_model import TaxBenefitModel +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .variable import Variable + from .parameter import Parameter + from .parameter_value import ParameterValue + from .simulation import Simulation class TaxBenefitModelVersion(BaseModel): @@ -11,4 +18,17 @@ class TaxBenefitModelVersion(BaseModel): model: TaxBenefitModel version: str description: str | None = None - created_at: datetime = Field(default_factory=datetime.now) + created_at: datetime | None = Field(default_factory=datetime.utcnow) + + variables: list["Variable"] = Field(default_factory=list) + parameters: list["Parameter"] = Field(default_factory=list) + parameter_values: list["ParameterValue"] = Field(default_factory=list) + + def run(self, simulation: "Simulation") -> "Simulation": + raise NotImplementedError( + "The TaxBenefitModel class must define a method to execute simulations." 
+ )
+
+    def __repr__(self) -> str:
+        # Give the id and version, and the number of variables, parameters, parameter values
+        return f"<{self.id} v{self.version}: {len(self.variables)} variables, {len(self.parameters)} parameters, {len(self.parameter_values)} parameter values>"
 diff --git a/src/policyengine/core/variable.py b/src/policyengine/core/variable.py index a0b8b246..a8da5913 100644 --- a/src/policyengine/core/variable.py +++ b/src/policyengine/core/variable.py @@ -7,6 +7,5 @@ class Variable(BaseModel): id: str tax_benefit_model_version: TaxBenefitModelVersion entity: str - name: str | None = None description: str | None = None data_type: type = None diff --git a/src/policyengine/tax_benefit_models/uk.py b/src/policyengine/tax_benefit_models/uk.py index 3b896b49..86b8b065 100644 --- a/src/policyengine/tax_benefit_models/uk.py +++ b/src/policyengine/tax_benefit_models/uk.py @@ -2,73 +2,162 @@ from pydantic import BaseModel, Field, ConfigDict import pandas as pd from typing import Dict +import datetime +import requests +from policyengine.utils import parse_safe_date +from pathlib import Path +from importlib.metadata import version -class YearData(BaseModel): + +class UKYearData(BaseModel): """Entity-level data for a single year.""" + model_config = ConfigDict(arbitrary_types_allowed=True) person: pd.DataFrame benunit: pd.DataFrame household: pd.DataFrame + class PolicyEngineUKDataset(Dataset): """UK dataset with single-year entity-level data.""" - data: Dict[int, YearData] = Field(default_factory=dict) - def save(self, filepath: str) -> None: + data: UKYearData | None = None + + def save(self) -> None: """Save dataset to HDF5 file.""" - with pd.HDFStore(filepath, mode='w') as store: - for year, year_data in self.data.items(): - store[f'{year}/person'] = year_data.person - store[f'{year}/benunit'] = year_data.benunit - store[f'{year}/household'] = year_data.household + filepath = self.filepath + with pd.HDFStore(filepath, mode="w") as store: + store["person"] = self.data.person + store["benunit"] = self.data.benunit + store["household"] = self.data.household - def load(self, filepath: str) -> None: + def load(self) -> None: """Load dataset from HDF5 file into this instance.""" - data = {} - with pd.HDFStore(filepath, mode='r') as store: - # Get all years from keys - years = set() - for key in store.keys(): - # Keys are like '/2025/person' - year = int(key.split('/')[1]) - years.add(year) - - # Load data for each year - for year in years: - data[year] = YearData( - person=store[f'{year}/person'], - benunit=store[f'{year}/benunit'], - household=store[f'{year}/household'] - ) + filepath = self.filepath + with pd.HDFStore(filepath, mode="r") as store: + self.data = UKYearData( + person=store["person"], + benunit=store["benunit"], + household=store["household"], + ) + + def __repr__(self) -> str: + if self.data is None: + return f"<PolicyEngineUKDataset {self.name} (no data loaded)>" + else: + n_people = len(self.data.person) + n_benunits = len(self.data.benunit) + n_households = len(self.data.household) + return f"<PolicyEngineUKDataset {self.name}: {n_people} people, {n_benunits} benunits, {n_households} households>" - self.data = data - self.filepath = filepath + class PolicyEngineUK(TaxBenefitModel): id: str = "policyengine-uk" + description: str = "The UK's open-source dynamic tax and benefit microsimulation model maintained by PolicyEngine." 
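As an aside, the reshaped dataset API above can be exercised end to end in a few lines. The following is a minimal sketch, not part of the patch: the frame contents and the demo.h5 path are invented, and pd.HDFStore needs the optional PyTables dependency installed.

```python
import pandas as pd

from policyengine.tax_benefit_models.uk import PolicyEngineUKDataset, UKYearData

# Tiny invented frames; a real dataset carries full survey microdata.
person = pd.DataFrame({"person_id": [1, 2], "age": [30, 40]})
benunit = pd.DataFrame({"benunit_id": [1]})
household = pd.DataFrame({"household_id": [1]})

dataset = PolicyEngineUKDataset(
    name="Demo",
    description="Round-trip example",
    filepath="demo.h5",  # hypothetical location
    year=2025,
    data=UKYearData(person=person, benunit=benunit, household=household),
)
dataset.save()  # writes flat keys: person, benunit, household

reloaded = PolicyEngineUKDataset(
    name="Demo",
    description="Round-trip example",
    filepath="demo.h5",
    year=2025,
)
reloaded.load()  # repopulates reloaded.data from the store
```

Note how save and load now read the location from self.filepath rather than taking it as an argument, which keeps the file and the in-memory object in step.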
+ + +uk_model = PolicyEngineUK() + +pkg_version = version("policyengine-uk") + +# Get published time from PyPI +response = requests.get("https://pypi.org/pypi/policyengine-uk/json") +data = response.json() +upload_time = data["releases"][pkg_version][0]["upload_time_iso_8601"] + + +class PolicyEngineUKLatest(TaxBenefitModelVersion): + model: TaxBenefitModel = uk_model + version: str = pkg_version + created_at: datetime.datetime = datetime.datetime.fromisoformat( + upload_time + ) + + def __init__(self, **kwargs: dict): + super().__init__(**kwargs) + from policyengine_uk.system import system + + self.id = f"{self.model.id}@{self.version}" + + self.variables = [] + for var_obj in system.variables.values(): + variable = Variable( + id=self.id + "-" + var_obj.name, + tax_benefit_model_version=self, + entity=var_obj.entity.key, + description=var_obj.documentation, + data_type=var_obj.value_type, + ) + self.variables.append(variable) + + self.parameters = [] + from policyengine_core.parameters import Parameter as CoreParameter + + for param_node in system.parameters.get_descendants(): + if isinstance(param_node, CoreParameter): + parameter = Parameter( + id=self.id + "-" + param_node.name, + name=param_node.name, + tax_benefit_model_version=self, + description=param_node.description, + data_type=type( + param_node(2025) + ), # Example year to infer type + unit=param_node.metadata.get("unit"), + ) + self.parameters.append(parameter) + + for i in range(len(param_node.values_list)): + param_at_instant = param_node.values_list[i] + if i + 1 < len(param_node.values_list): + next_instant = param_node.values_list[i + 1] + else: + next_instant = None + parameter_value = ParameterValue( + parameter=parameter, + start_date=parse_safe_date( + param_at_instant.instant_str + ), + end_date=parse_safe_date(next_instant.instant_str) + if next_instant + else None, + value=param_at_instant.value, + ) + self.parameter_values.append(parameter_value) def run(self, simulation: "Simulation") -> "Simulation": from policyengine_uk import Microsimulation - from policyengine_uk.data import UKSingleYearDataset, UKMultiYearDataset + from policyengine_uk.data import UKSingleYearDataset assert isinstance(simulation.dataset, PolicyEngineUKDataset) dataset = simulation.dataset dataset.load() - year_data = dataset.data[next(iter(dataset.data))] input_data = UKSingleYearDataset( - person=year_data.person, - benunit=year_data.benunit, - household=year_data.household, - fiscal_year=next(iter(dataset.data)) + person=dataset.data.person, + benunit=dataset.data.benunit, + household=dataset.data.household, + fiscal_year=dataset.year, ) microsim = Microsimulation(dataset=input_data) - + entity_variables = { - "person": ["person_id", "benunit_id", "household_id", "age", "employment_income", "person_weight"], + "person": [ + "person_id", + "benunit_id", + "household_id", + "age", + "employment_income", + "person_weight", + ], "benunit": ["benunit_id", "benunit_weight"], - "household": ["household_id", "household_weight", "hbai_household_net_income", "equiv_hbai_household_net_income"], + "household": [ + "household_id", + "household_weight", + "hbai_household_net_income", + "equiv_hbai_household_net_income", + ], } data = { @@ -79,23 +168,31 @@ def run(self, simulation: "Simulation") -> "Simulation": for entity, variables in entity_variables.items(): for var in variables: - data[entity][var] = microsim.calculate(var, period=simulation.year, map_to=entity).values - - output_dataset = PolicyEngineUKDataset( + data[entity][var] = 
microsim.calculate( + var, period=simulation.dataset.year, map_to=entity + ).values + + simulation.output_dataset = PolicyEngineUKDataset( name=dataset.name, description=dataset.description, - filepath=dataset.filepath, - data={ - simulation.year: YearData( - person=data["person"], - benunit=data["benunit"], - household=data["household"] - ) - } + filepath=str( + Path(simulation.dataset.filepath).parent + / (simulation.id + ".h5") + ), + year=simulation.dataset.year, + is_output_dataset=True, + data=UKYearData( + person=data["person"], + benunit=data["benunit"], + household=data["household"], + ), ) - output_dataset.save(dataset.filepath) + simulation.output_dataset.save() # Rebuild models to resolve forward references PolicyEngineUKDataset.model_rebuild() +PolicyEngineUKLatest.model_rebuild() + +uk_latest = PolicyEngineUKLatest() diff --git a/src/policyengine/utils/__init__.py b/src/policyengine/utils/__init__.py new file mode 100644 index 00000000..6761220d --- /dev/null +++ b/src/policyengine/utils/__init__.py @@ -0,0 +1 @@ +from .dates import parse_safe_date diff --git a/src/policyengine/utils/dates.py b/src/policyengine/utils/dates.py new file mode 100644 index 00000000..6bcacab1 --- /dev/null +++ b/src/policyengine/utils/dates.py @@ -0,0 +1,24 @@ +from datetime import datetime + + +def parse_safe_date(date_string: str) -> datetime: + """ + Parse a YYYY-MM-DD date string and ensure the year is at least 1. + + Args: + date_string: Date string in YYYY-MM-DD format + + Returns: + Safe datetime object with year >= 1 + """ + try: + date_string = date_string.replace("0000-", "0001-") + date_obj = datetime.strptime(date_string, "%Y-%m-%d") + if date_obj.year < 1: + # Replace year 0 or negative years with year 1 + return date_obj.replace(year=1) + return date_obj + except ValueError: + raise ValueError( + f"Invalid date format: {date_string}. 
Expected YYYY-MM-DD" + ) diff --git a/tests/test_uk_dataset.py b/tests/test_uk_dataset.py index 620de616..d4e2ed83 100644 --- a/tests/test_uk_dataset.py +++ b/tests/test_uk_dataset.py @@ -1,168 +1,82 @@ import pandas as pd import tempfile import os -from policyengine.tax_benefit_models.uk import PolicyEngineUKDataset, YearData +from policyengine.core import * +from policyengine.tax_benefit_models.uk import ( + PolicyEngineUKDataset, + UKYearData, +) -def test_save_and_load_single_year(): - """Test saving and loading a dataset with a single year.""" - # Create sample data - person_df = pd.DataFrame({ - 'person_id': [1, 2, 3], - 'age': [25, 30, 35], - 'income': [30000, 45000, 60000] - }) - - benunit_df = pd.DataFrame({ - 'benunit_id': [1, 2], - 'size': [2, 1], - 'total_income': [75000, 60000] - }) - - household_df = pd.DataFrame({ - 'household_id': [1], - 'num_people': [3], - 'rent': [1200] - }) - - # Create dataset - dataset = PolicyEngineUKDataset( - name='Test Dataset', - description='A test dataset', - filepath='test.h5', - data={ - 2025: YearData( - person=person_df, - benunit=benunit_df, - household=household_df - ) - } - ) - - # Save to temporary file - with tempfile.TemporaryDirectory() as tmpdir: - filepath = os.path.join(tmpdir, 'test_dataset.h5') - dataset.save(filepath) - - # Load it back - loaded = PolicyEngineUKDataset( - name='Loaded Dataset', - description='Loaded from file', - filepath=filepath - ) - loaded.load(filepath) - - # Verify data - assert 2025 in loaded.data - pd.testing.assert_frame_equal(loaded.data[2025].person, person_df) - pd.testing.assert_frame_equal(loaded.data[2025].benunit, benunit_df) - pd.testing.assert_frame_equal(loaded.data[2025].household, household_df) - - -def test_save_and_load_multiple_years(): - """Test saving and loading a dataset with multiple years.""" - # Create sample data for 2025 - person_2025 = pd.DataFrame({ - 'person_id': [1, 2], - 'age': [25, 30], - 'income': [30000, 45000] - }) - - benunit_2025 = pd.DataFrame({ - 'benunit_id': [1], - 'size': [2], - 'total_income': [75000] - }) +def test_imports(): + """Test that basic imports work.""" + # Verify classes are importable + assert PolicyEngineUKDataset is not None + assert UKYearData is not None + assert Dataset is not None + assert TaxBenefitModel is not None - household_2025 = pd.DataFrame({ - 'household_id': [1], - 'num_people': [2], - 'rent': [1200] - }) - # Create sample data for 2026 - person_2026 = pd.DataFrame({ - 'person_id': [1, 2, 3], - 'age': [26, 31, 22], - 'income': [32000, 47000, 28000] - }) +def test_uk_latest_instantiation(): + """Test that uk_latest can be instantiated without errors.""" + from policyengine.tax_benefit_models.uk import uk_latest - benunit_2026 = pd.DataFrame({ - 'benunit_id': [1, 2], - 'size': [2, 1], - 'total_income': [79000, 28000] - }) + assert uk_latest is not None + assert uk_latest.version is not None + assert uk_latest.model is not None + assert uk_latest.created_at is not None + assert ( + len(uk_latest.variables) > 0 + ) # Should have variables from policyengine-uk - household_2026 = pd.DataFrame({ - 'household_id': [1], - 'num_people': [3], - 'rent': [1300] - }) - # Create dataset with multiple years - dataset = PolicyEngineUKDataset( - name='Multi-year Dataset', - description='Dataset with multiple years', - filepath='test.h5', - data={ - 2025: YearData( - person=person_2025, - benunit=benunit_2025, - household=household_2025 - ), - 2026: YearData( - person=person_2026, - benunit=benunit_2026, - household=household_2026 - ) +def 
test_save_and_load_single_year(): + """Test saving and loading a dataset with a single year.""" + # Create sample data + person_df = pd.DataFrame( + { + "person_id": [1, 2, 3], + "age": [25, 30, 35], + "income": [30000, 45000, 60000], } ) - # Save and load - with tempfile.TemporaryDirectory() as tmpdir: - filepath = os.path.join(tmpdir, 'multi_year_dataset.h5') - dataset.save(filepath) - - loaded = PolicyEngineUKDataset( - name='Loaded Multi-year', - description='Loaded from file', - filepath=filepath - ) - loaded.load(filepath) - - # Verify both years exist - assert 2025 in loaded.data - assert 2026 in loaded.data - - # Verify 2025 data - pd.testing.assert_frame_equal(loaded.data[2025].person, person_2025) - pd.testing.assert_frame_equal(loaded.data[2025].benunit, benunit_2025) - pd.testing.assert_frame_equal(loaded.data[2025].household, household_2025) - - # Verify 2026 data - pd.testing.assert_frame_equal(loaded.data[2026].person, person_2026) - pd.testing.assert_frame_equal(loaded.data[2026].benunit, benunit_2026) - pd.testing.assert_frame_equal(loaded.data[2026].household, household_2026) - + benunit_df = pd.DataFrame( + {"benunit_id": [1, 2], "size": [2, 1], "total_income": [75000, 60000]} + ) -def test_empty_dataset(): - """Test creating and saving an empty dataset.""" - dataset = PolicyEngineUKDataset( - name='Empty Dataset', - description='No data yet', - filepath='empty.h5', - data={} + household_df = pd.DataFrame( + {"household_id": [1], "num_people": [3], "rent": [1200]} ) + # Create dataset with tempfile.TemporaryDirectory() as tmpdir: - filepath = os.path.join(tmpdir, 'empty_dataset.h5') - dataset.save(filepath) + filepath = os.path.join(tmpdir, "test_dataset.h5") + + dataset = PolicyEngineUKDataset( + name="Test Dataset", + description="A test dataset", + filepath=filepath, + year=2025, + data=UKYearData( + person=person_df, benunit=benunit_df, household=household_df + ), + ) + + # Save to file + dataset.save() + # Load it back loaded = PolicyEngineUKDataset( - name='Loaded Empty', - description='Empty dataset loaded', - filepath=filepath + name="Loaded Dataset", + description="Loaded from file", + filepath=filepath, + year=2025, ) - loaded.load(filepath) + loaded.load() - assert len(loaded.data) == 0 + # Verify data + assert loaded.year == 2025 + pd.testing.assert_frame_equal(loaded.data.person, person_df) + pd.testing.assert_frame_equal(loaded.data.benunit, benunit_df) + pd.testing.assert_frame_equal(loaded.data.household, household_df) From 702c79e41b5d913afafa43b799aa21bb258ea829 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Wed, 12 Nov 2025 17:43:04 +0000 Subject: [PATCH 18/35] Update --- .gitignore | 6 +- CLAUDE.md | 10 + Makefile | 8 +- examples/income_bands.py | 167 ++++++++ pyproject.toml | 3 + src/policyengine/core/dataset.py | 21 +- src/policyengine/core/simulation.py | 7 +- src/policyengine/core/variable.py | 3 + src/policyengine/outputs/aggregate.py | 94 +++++ src/policyengine/tax_benefit_models/uk.py | 277 ++++++++++-- tests/test_aggregate.py | 490 ++++++++++++++++++++++ tests/test_entity_mapping.py | 384 +++++++++++++++++ tests/test_uk_dataset.py | 54 ++- 13 files changed, 1476 insertions(+), 48 deletions(-) create mode 100644 examples/income_bands.py create mode 100644 src/policyengine/outputs/aggregate.py create mode 100644 tests/test_aggregate.py create mode 100644 tests/test_entity_mapping.py diff --git a/.gitignore b/.gitignore index dc335293..69dfad02 100644 --- a/.gitignore +++ b/.gitignore @@ -1,9 +1,7 @@ **/*.db **/__pycache__ **/*.egg-info 
+**/*.h5 +*.ipynb _build/ -simulations/ -test.* -supabase/ .env -**/review.md diff --git a/CLAUDE.md b/CLAUDE.md index 31c70c85..be48ac80 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -5,3 +5,13 @@ Claude, please follow these always. These principles are aimed at preventing you 1. British English, sentence case 2. No excessive duplication, keep code files as concise as possible to produce the same meaningful value. No excessive printing 3. Don't create multiple files for successive versions. Keep checking: have I added lots of intermediate files which are deprecated? Delete them if so, but ideally don't create them in the first place + +## MicroDataFrame + +A pandas DataFrame that automatically handles weights for survey microdata. Key features: + +- Create with `MicroDataFrame(df, weights='weight_column')` +- All aggregations (sum, mean, etc.) automatically weight results +- Each column is a MicroSeries with weighted operations +- Use `.groupby()` for weighted group statistics +- Built-in poverty analysis: `.poverty_rate()`, `.poverty_gap()` diff --git a/Makefile b/Makefile index 931fccdd..f1e4b163 100644 --- a/Makefile +++ b/Makefile @@ -12,7 +12,13 @@ format: ruff format . clean: - rm -rf **/__pycache__ _build **/_build .pytest_cache .ruff_cache **/*.egg-info **/*.pyc + find . -not -path "./.venv/*" -type d -name "__pycache__" -exec rm -rf {} + + find . -not -path "./.venv/*" -type d -name "_build" -exec rm -rf {} + + find . -not -path "./.venv/*" -type d -name ".pytest_cache" -exec rm -rf {} + + find . -not -path "./.venv/*" -type d -name ".ruff_cache" -exec rm -rf {} + + find . -not -path "./.venv/*" -type d -name "*.egg-info" -exec rm -rf {} + + find . -not -path "./.venv/*" -type f -name "*.pyc" -delete + find . -not -path "./.venv/*" -type f -name "*.h5" -delete changelog: build-changelog changelog.yaml --output changelog.yaml --update-last-date --start-from 1.0.0 --append-file changelog_entry.yaml diff --git a/examples/income_bands.py b/examples/income_bands.py new file mode 100644 index 00000000..bd665d1f --- /dev/null +++ b/examples/income_bands.py @@ -0,0 +1,167 @@ +"""Example: Calculate total employment income by income band. + +This script demonstrates: +1. Creating a dataset with randomly sampled incomes (exponential distribution) +2. Using Aggregate to calculate statistics within income bands +3. Filtering with geq/leq constraints +4. 
Visualising results with Plotly + +Run: python examples/income_bands.py +""" + +import numpy as np +import pandas as pd +from microdf import MicroDataFrame +import plotly.graph_objects as go +from plotly.subplots import make_subplots +from policyengine.core import Simulation +from policyengine.tax_benefit_models.uk import ( + PolicyEngineUKDataset, + UKYearData, + uk_latest, +) +from policyengine.outputs.aggregate import Aggregate, AggregateType + +# Create sample data with random incomes (simplified - no simulation needed) +np.random.seed(42) +n_people = 1000 + +person_df = MicroDataFrame( + pd.DataFrame( + { + "person_id": range(1, n_people + 1), + "benunit_id": range(1, n_people + 1), + "household_id": range(1, n_people + 1), + "age": np.random.randint(18, 70, n_people), + "employment_income": np.random.exponential(35000, n_people), + "person_weight": np.ones(n_people), + } + ), + weights="person_weight", +) + +benunit_df = MicroDataFrame( + pd.DataFrame( + { + "benunit_id": range(1, n_people + 1), + "benunit_weight": np.ones(n_people), + } + ), + weights="benunit_weight", +) + +household_df = MicroDataFrame( + pd.DataFrame( + { + "household_id": range(1, n_people + 1), + "household_weight": np.ones(n_people), + } + ), + weights="household_weight", +) + +# Create dataset (this serves as our output dataset) +dataset = PolicyEngineUKDataset( + name="Sample Dataset", + description="Random sample for testing", + filepath="./sample_data.h5", + year=2024, + data=UKYearData( + person=person_df, benunit=benunit_df, household=household_df + ), +) + +# Create simulation with dataset as output +simulation = Simulation( + dataset=dataset, + tax_benefit_model_version=uk_latest, + output_dataset=dataset, +) + +# Calculate total income by 10k bands +bands = [] +totals = [] +counts = [] + +for lower in range(0, 100000, 10000): + upper = lower + 10000 + + agg = Aggregate( + simulation=simulation, + variable="employment_income", + aggregate_type=AggregateType.SUM, + filter_variable="employment_income", + filter_variable_geq=lower, + filter_variable_leq=upper, + ) + agg.run() + + count_agg = Aggregate( + simulation=simulation, + variable="employment_income", + aggregate_type=AggregateType.COUNT, + filter_variable="employment_income", + filter_variable_geq=lower, + filter_variable_leq=upper, + ) + count_agg.run() + + bands.append(f"£{lower // 1000}k-£{upper // 1000}k") + totals.append(agg.result) + counts.append(count_agg.result) + +# Calculate 100k+ band +agg = Aggregate( + simulation=simulation, + variable="employment_income", + aggregate_type=AggregateType.SUM, + filter_variable="employment_income", + filter_variable_geq=100000, +) +agg.run() + +count_agg = Aggregate( + simulation=simulation, + variable="employment_income", + aggregate_type=AggregateType.COUNT, + filter_variable="employment_income", + filter_variable_geq=100000, +) +count_agg.run() + +bands.append("£100k+") +totals.append(agg.result) +counts.append(count_agg.result) + +# Create chart +fig = make_subplots( + rows=1, + cols=2, + subplot_titles=("Total income by band", "Population by band"), + specs=[[{"type": "bar"}, {"type": "bar"}]], +) + +fig.add_trace( + go.Bar(x=bands, y=totals, name="Total income", marker_color="lightblue"), + row=1, + col=1, +) + +fig.add_trace( + go.Bar(x=bands, y=counts, name="Count", marker_color="lightgreen"), + row=1, + col=2, +) + +fig.update_xaxes(title_text="Income band", row=1, col=1) +fig.update_xaxes(title_text="Income band", row=1, col=2) +fig.update_yaxes(title_text="Total income (£)", row=1, col=1) 
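+# Note: the geq/leq filters in the band loop above are both inclusive, so a
+# value landing exactly on a band edge (say 10000.0) would satisfy two
+# adjacent bands. With continuous exponential draws an exact edge hit is a
+# measure-zero event, so the totals charted here are unaffected in practice.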
+fig.update_yaxes(title_text="Number of people", row=1, col=2) + +fig.update_layout( + title_text="Employment income distribution", + showlegend=False, + height=400, +) + +fig.show() diff --git a/pyproject.toml b/pyproject.toml index 251a2d68..58df3c7d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,6 +72,9 @@ addopts = "-v" testpaths = [ "tests", ] +filterwarnings = [ + "ignore::pydantic.warnings.PydanticDeprecatedSince20", +] [tool.black] line-length = 79 diff --git a/src/policyengine/core/dataset.py b/src/policyengine/core/dataset.py index 7275d2df..34997d01 100644 --- a/src/policyengine/core/dataset.py +++ b/src/policyengine/core/dataset.py @@ -1,13 +1,28 @@ -from typing import Any from uuid import uuid4 -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ConfigDict from .tax_benefit_model import TaxBenefitModel from .dataset_version import DatasetVersion class Dataset(BaseModel): + """Base class for datasets. + + The data field contains entity-level data as a BaseModel with DataFrame fields. + + Example: + class YearData(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + person: pd.DataFrame + household: pd.DataFrame + + class MyDataset(Dataset): + data: YearData | None = None + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + id: str = Field(default_factory=lambda: str(uuid4())) name: str description: str @@ -16,3 +31,5 @@ class Dataset(BaseModel): is_output_dataset: bool = False tax_benefit_model: TaxBenefitModel = None year: int + + data: BaseModel | None = None diff --git a/src/policyengine/core/simulation.py b/src/policyengine/core/simulation.py index 8006f1de..66ece502 100644 --- a/src/policyengine/core/simulation.py +++ b/src/policyengine/core/simulation.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any +from typing import Any, Dict, List from uuid import uuid4 from pydantic import BaseModel, Field @@ -23,5 +23,10 @@ class Simulation(BaseModel): tax_benefit_model_version: TaxBenefitModelVersion = None output_dataset: Dataset | None = None + variables: Dict[str, List[str]] | None = Field( + default=None, + description="Optional dictionary mapping entity names to lists of variable names to calculate. 
If None, uses model defaults.", + ) + def run(self): self.tax_benefit_model_version.run(self) diff --git a/src/policyengine/core/variable.py b/src/policyengine/core/variable.py index a8da5913..2428a0a7 100644 --- a/src/policyengine/core/variable.py +++ b/src/policyengine/core/variable.py @@ -1,11 +1,14 @@ from pydantic import BaseModel +from typing import Any from .tax_benefit_model_version import TaxBenefitModelVersion class Variable(BaseModel): id: str + name: str tax_benefit_model_version: TaxBenefitModelVersion entity: str description: str | None = None data_type: type = None + possible_values: list[Any] | None = None diff --git a/src/policyengine/outputs/aggregate.py b/src/policyengine/outputs/aggregate.py new file mode 100644 index 00000000..5c0c08c1 --- /dev/null +++ b/src/policyengine/outputs/aggregate.py @@ -0,0 +1,94 @@ +from pydantic import BaseModel, Field +from policyengine.core import * +from enum import Enum +from typing import Any + + +class AggregateType(str, Enum): + SUM = "sum" + MEAN = "mean" + COUNT = "count" + + +class Aggregate(BaseModel): + simulation: Simulation + variable: str + aggregate_type: AggregateType + entity: str | None = None + + filter_variable: str | None = None + filter_variable_eq: Any | None = None + filter_variable_leq: Any | None = None + filter_variable_geq: Any | None = None + filter_variable_describes_quantiles: bool = False + + result: Any | None = None + + def run(self): + # Get variable object + var_obj = next( + v + for v in self.simulation.tax_benefit_model_version.variables + if v.name == self.variable + ) + + # Get the target entity data + target_entity = self.entity or var_obj.entity + data = getattr(self.simulation.output_dataset.data, target_entity) + + # Map variable to target entity if needed + if var_obj.entity != target_entity: + mapped = self.simulation.output_dataset.data.map_to_entity( + var_obj.entity, target_entity + ) + series = mapped[self.variable] + else: + series = data[self.variable] + + # Apply filters + if self.filter_variable is not None: + filter_var_obj = next( + v + for v in self.simulation.tax_benefit_model_version.variables + if v.name == self.filter_variable + ) + + if filter_var_obj.entity != target_entity: + filter_mapped = ( + self.simulation.output_dataset.data.map_to_entity( + filter_var_obj.entity, target_entity + ) + ) + filter_series = filter_mapped[self.filter_variable] + else: + filter_series = data[self.filter_variable] + + if self.filter_variable_describes_quantiles: + if self.filter_variable_eq is not None: + threshold = filter_series.quantile(self.filter_variable_eq) + series = series[filter_series <= threshold] + if self.filter_variable_leq is not None: + threshold = filter_series.quantile( + self.filter_variable_leq + ) + series = series[filter_series <= threshold] + if self.filter_variable_geq is not None: + threshold = filter_series.quantile( + self.filter_variable_geq + ) + series = series[filter_series >= threshold] + else: + if self.filter_variable_eq is not None: + series = series[filter_series == self.filter_variable_eq] + if self.filter_variable_leq is not None: + series = series[filter_series <= self.filter_variable_leq] + if self.filter_variable_geq is not None: + series = series[filter_series >= self.filter_variable_geq] + + # Aggregate + if self.aggregate_type == AggregateType.SUM: + self.result = series.sum() + elif self.aggregate_type == AggregateType.MEAN: + self.result = series.mean() + elif self.aggregate_type == AggregateType.COUNT: + self.result = series.count() diff --git 
a/src/policyengine/tax_benefit_models/uk.py b/src/policyengine/tax_benefit_models/uk.py index 86b8b065..b6e14aaf 100644 --- a/src/policyengine/tax_benefit_models/uk.py +++ b/src/policyengine/tax_benefit_models/uk.py @@ -7,6 +7,7 @@ from policyengine.utils import parse_safe_date from pathlib import Path from importlib.metadata import version +from microdf import MicroDataFrame class UKYearData(BaseModel): @@ -14,9 +15,123 @@ class UKYearData(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) - person: pd.DataFrame - benunit: pd.DataFrame - household: pd.DataFrame + person: MicroDataFrame + benunit: MicroDataFrame + household: MicroDataFrame + + def map_to_entity( + self, source_entity: str, target_entity: str, columns: list[str] = None + ) -> MicroDataFrame: + """Map data from source entity to target entity using join keys. + + Args: + source_entity (str): The source entity name ('person', 'benunit', 'household'). + target_entity (str): The target entity name ('person', 'benunit', 'household'). + columns (list[str], optional): List of column names to map. If None, maps all columns. + + Returns: + MicroDataFrame: The mapped data at the target entity level. + + Raises: + ValueError: If source or target entity is invalid. + """ + valid_entities = {"person", "benunit", "household"} + if source_entity not in valid_entities: + raise ValueError( + f"Invalid source entity '{source_entity}'. Must be one of {valid_entities}" + ) + if target_entity not in valid_entities: + raise ValueError( + f"Invalid target entity '{target_entity}'. Must be one of {valid_entities}" + ) + + # Get source data + source_df = getattr(self, source_entity) + if columns: + # Select only requested columns (keep join keys) + join_keys = {"person_id", "benunit_id", "household_id"} + cols_to_keep = list( + set(columns) | (join_keys & set(source_df.columns)) + ) + source_df = source_df[cols_to_keep] + + # Determine weight column for target entity + weight_col_map = { + "person": "person_weight", + "benunit": "benunit_weight", + "household": "household_weight", + } + target_weight = weight_col_map[target_entity] + + # Same entity - return as is + if source_entity == target_entity: + return MicroDataFrame( + pd.DataFrame(source_df), weights=target_weight + ) + + # Map to different entity + target_df = getattr(self, target_entity) + + # Person -> Benunit + if source_entity == "person" and target_entity == "benunit": + result = pd.DataFrame(target_df).merge( + pd.DataFrame(source_df), on="benunit_id", how="left" + ) + return MicroDataFrame(result, weights=target_weight) + + # Person -> Household + elif source_entity == "person" and target_entity == "household": + result = pd.DataFrame(target_df).merge( + pd.DataFrame(source_df), on="household_id", how="left" + ) + return MicroDataFrame(result, weights=target_weight) + + # Benunit -> Person + elif source_entity == "benunit" and target_entity == "person": + result = pd.DataFrame(target_df).merge( + pd.DataFrame(source_df), on="benunit_id", how="left" + ) + return MicroDataFrame(result, weights=target_weight) + + # Benunit -> Household + elif source_entity == "benunit" and target_entity == "household": + # Need to go through person to link benunit and household + person_link = pd.DataFrame(self.person)[ + ["benunit_id", "household_id"] + ].drop_duplicates() + source_with_hh = pd.DataFrame(source_df).merge( + person_link, on="benunit_id", how="left" + ) + result = pd.DataFrame(target_df).merge( + source_with_hh, on="household_id", how="left" + ) + return 
MicroDataFrame(result, weights=target_weight) + + # Household -> Person + elif source_entity == "household" and target_entity == "person": + result = pd.DataFrame(target_df).merge( + pd.DataFrame(source_df), on="household_id", how="left" + ) + return MicroDataFrame(result, weights=target_weight) + + # Household -> Benunit + elif source_entity == "household" and target_entity == "benunit": + # Need to go through person to link household and benunit + person_link = pd.DataFrame(self.person)[ + ["benunit_id", "household_id"] + ].drop_duplicates() + source_with_bu = pd.DataFrame(source_df).merge( + person_link, on="household_id", how="left" + ) + result = pd.DataFrame(target_df).merge( + source_with_bu, on="benunit_id", how="left" + ) + return MicroDataFrame(result, weights=target_weight) + + else: + raise ValueError( + f"Unsupported mapping from {source_entity} to {target_entity}" + ) class PolicyEngineUKDataset(Dataset): @@ -24,22 +139,37 @@ class PolicyEngineUKDataset(Dataset): data: UKYearData | None = None + def __init__(self, **kwargs: dict): + super().__init__(**kwargs) + + # Make sure we are synchronised between in-memory and storage, at least on initialisation. + if "data" in kwargs: + self.save() + elif "filepath" in kwargs: + self.load() + def save(self) -> None: """Save dataset to HDF5 file.""" filepath = self.filepath with pd.HDFStore(filepath, mode="w") as store: - store["person"] = self.data.person - store["benunit"] = self.data.benunit - store["household"] = self.data.household + store["person"] = pd.DataFrame(self.data.person) + store["benunit"] = pd.DataFrame(self.data.benunit) + store["household"] = pd.DataFrame(self.data.household) def load(self) -> None: """Load dataset from HDF5 file into this instance.""" filepath = self.filepath with pd.HDFStore(filepath, mode="r") as store: self.data = UKYearData( - person=store["person"], - benunit=store["benunit"], - household=store["household"], + person=MicroDataFrame( + store["person"], weights="person_weight" + ), + benunit=MicroDataFrame( + store["benunit"], weights="benunit_weight" + ), + household=MicroDataFrame( + store["household"], weights="household_weight" + ), ) def __repr__(self) -> str: @@ -77,6 +207,7 @@ class PolicyEngineUKLatest(TaxBenefitModelVersion): def __init__(self, **kwargs: dict): super().__init__(**kwargs) from policyengine_uk.system import system + from policyengine_core.enums import Enum self.id = f"{self.model.id}@{self.version}" @@ -84,11 +215,24 @@ def __init__(self, **kwargs: dict): for var_obj in system.variables.values(): variable = Variable( id=self.id + "-" + var_obj.name, + name=var_obj.name, tax_benefit_model_version=self, entity=var_obj.entity.key, description=var_obj.documentation, - data_type=var_obj.value_type, + data_type=var_obj.value_type + if var_obj.value_type is not Enum + else str, ) + if ( + hasattr(var_obj, "possible_values") + and var_obj.possible_values is not None + ): + variable.possible_values = list( + map( + lambda x: x.name, + var_obj.possible_values._value2member_map_.values(), + ) + ) self.variables.append(variable) self.parameters = [] @@ -142,23 +286,92 @@ def run(self, simulation: "Simulation") -> "Simulation": ) microsim = Microsimulation(dataset=input_data) - entity_variables = { - "person": [ - "person_id", - "benunit_id", - "household_id", - "age", - "employment_income", - "person_weight", - ], - "benunit": ["benunit_id", "benunit_weight"], - "household": [ - "household_id", - "household_weight", - "hbai_household_net_income", - "equiv_hbai_household_net_income", - 
], - } + if ( + simulation.policy + and simulation.policy.simulation_modifier is not None + ): + simulation.policy.simulation_modifier(microsim) + + if ( + simulation.dynamic + and simulation.dynamic.simulation_modifier is not None + ): + simulation.dynamic.simulation_modifier(microsim) + + # Allow custom variable selection, or use defaults + if simulation.variables is not None: + entity_variables = simulation.variables + else: + # Default comprehensive variable set + entity_variables = { + "person": [ + # IDs and weights + "person_id", + "benunit_id", + "household_id", + "person_weight", + # Demographics + "age", + "gender", + "is_adult", + "is_child", + # Income + "employment_income", + "self_employment_income", + "pension_income", + "private_pension_income", + "savings_interest_income", + "dividend_income", + "property_income", + "total_income", + "earned_income", + # Benefits + "universal_credit", + "child_benefit", + "pension_credit", + "income_support", + "working_tax_credit", + "child_tax_credit", + # Tax + "income_tax", + "national_insurance", + "net_income", + ], + "benunit": [ + # IDs and weights + "benunit_id", + "benunit_weight", + # Structure + "family_type", + "num_adults", + "num_children", + # Income and benefits + "benunit_total_income", + "benunit_net_income", + "universal_credit", + "child_benefit", + "working_tax_credit", + "child_tax_credit", + ], + "household": [ + # IDs and weights + "household_id", + "household_weight", + # Income measures + "household_net_income", + "hbai_household_net_income", + "equiv_hbai_household_net_income", + "household_market_income", + "household_gross_income", + # Benefits and tax + "household_benefits", + "household_tax", + # Housing + "rent", + "council_tax", + "housing_benefit", + ], + } data = { "person": pd.DataFrame(), @@ -172,6 +385,16 @@ def run(self, simulation: "Simulation") -> "Simulation": var, period=simulation.dataset.year, map_to=entity ).values + data["person"] = MicroDataFrame( + data["person"], weights="person_weight" + ) + data["benunit"] = MicroDataFrame( + data["benunit"], weights="benunit_weight" + ) + data["household"] = MicroDataFrame( + data["household"], weights="household_weight" + ) + simulation.output_dataset = PolicyEngineUKDataset( name=dataset.name, description=dataset.description, diff --git a/tests/test_aggregate.py b/tests/test_aggregate.py new file mode 100644 index 00000000..57c1a0fe --- /dev/null +++ b/tests/test_aggregate.py @@ -0,0 +1,490 @@ +import pandas as pd +import tempfile +import os +from microdf import MicroDataFrame +from policyengine.core import * +from policyengine.tax_benefit_models.uk import ( + PolicyEngineUKDataset, + UKYearData, + uk_latest, +) +from policyengine.outputs.aggregate import Aggregate, AggregateType + + +def test_aggregate_sum(): + """Test basic sum aggregation.""" + person_df = MicroDataFrame( + pd.DataFrame( + { + "person_id": [1, 2, 3], + "benunit_id": [1, 1, 2], + "household_id": [1, 1, 2], + "age": [30, 25, 40], + "employment_income": [50000, 30000, 60000], + "person_weight": [1.0, 1.0, 1.0], + } + ), + weights="person_weight", + ) + + benunit_df = MicroDataFrame( + pd.DataFrame( + { + "benunit_id": [1, 2], + "benunit_weight": [1.0, 1.0], + } + ), + weights="benunit_weight", + ) + + household_df = MicroDataFrame( + pd.DataFrame( + { + "household_id": [1, 2], + "household_weight": [1.0, 1.0], + } + ), + weights="household_weight", + ) + + with tempfile.TemporaryDirectory() as tmpdir: + filepath = os.path.join(tmpdir, "test.h5") + + dataset = PolicyEngineUKDataset( + 
name="Test", + description="Test dataset", + filepath=filepath, + year=2024, + data=UKYearData( + person=person_df, benunit=benunit_df, household=household_df + ), + ) + + simulation = Simulation( + dataset=dataset, + tax_benefit_model_version=uk_latest, + output_dataset=dataset, + ) + + agg = Aggregate( + simulation=simulation, + variable="employment_income", + aggregate_type=AggregateType.SUM, + ) + agg.run() + + assert agg.result == 140000 + + +def test_aggregate_mean(): + """Test mean aggregation.""" + person_df = MicroDataFrame( + pd.DataFrame( + { + "person_id": [1, 2, 3], + "benunit_id": [1, 1, 2], + "household_id": [1, 1, 2], + "age": [30, 25, 40], + "employment_income": [50000, 30000, 60000], + "person_weight": [1.0, 1.0, 1.0], + } + ), + weights="person_weight", + ) + + benunit_df = MicroDataFrame( + pd.DataFrame( + { + "benunit_id": [1, 2], + "benunit_weight": [1.0, 1.0], + } + ), + weights="benunit_weight", + ) + + household_df = MicroDataFrame( + pd.DataFrame( + { + "household_id": [1, 2], + "household_weight": [1.0, 1.0], + } + ), + weights="household_weight", + ) + + with tempfile.TemporaryDirectory() as tmpdir: + filepath = os.path.join(tmpdir, "test.h5") + + dataset = PolicyEngineUKDataset( + name="Test", + description="Test dataset", + filepath=filepath, + year=2024, + data=UKYearData( + person=person_df, benunit=benunit_df, household=household_df + ), + ) + + simulation = Simulation( + dataset=dataset, + tax_benefit_model_version=uk_latest, + output_dataset=dataset, + ) + + agg = Aggregate( + simulation=simulation, + variable="employment_income", + aggregate_type=AggregateType.MEAN, + ) + agg.run() + + assert abs(agg.result - 46666.67) < 1 + + +def test_aggregate_count(): + """Test count aggregation.""" + person_df = MicroDataFrame( + pd.DataFrame( + { + "person_id": [1, 2, 3], + "benunit_id": [1, 1, 2], + "household_id": [1, 1, 2], + "age": [30, 25, 40], + "employment_income": [50000, 30000, 60000], + "person_weight": [1.0, 1.0, 1.0], + } + ), + weights="person_weight", + ) + + benunit_df = MicroDataFrame( + pd.DataFrame( + { + "benunit_id": [1, 2], + "benunit_weight": [1.0, 1.0], + } + ), + weights="benunit_weight", + ) + + household_df = MicroDataFrame( + pd.DataFrame( + { + "household_id": [1, 2], + "household_weight": [1.0, 1.0], + } + ), + weights="household_weight", + ) + + with tempfile.TemporaryDirectory() as tmpdir: + filepath = os.path.join(tmpdir, "test.h5") + + dataset = PolicyEngineUKDataset( + name="Test", + description="Test dataset", + filepath=filepath, + year=2024, + data=UKYearData( + person=person_df, benunit=benunit_df, household=household_df + ), + ) + + simulation = Simulation( + dataset=dataset, + tax_benefit_model_version=uk_latest, + output_dataset=dataset, + ) + + agg = Aggregate( + simulation=simulation, + variable="employment_income", + aggregate_type=AggregateType.COUNT, + ) + agg.run() + + assert agg.result == 3 + + +def test_aggregate_with_entity_mapping(): + """Test aggregation with entity mapping (person var at household level).""" + person_df = MicroDataFrame( + pd.DataFrame( + { + "person_id": [1, 2, 3], + "benunit_id": [1, 1, 2], + "household_id": [1, 1, 2], + "age": [30, 25, 40], + "employment_income": [50000, 30000, 60000], + "person_weight": [1.0, 1.0, 1.0], + } + ), + weights="person_weight", + ) + + benunit_df = MicroDataFrame( + pd.DataFrame( + { + "benunit_id": [1, 2], + "benunit_weight": [1.0, 1.0], + } + ), + weights="benunit_weight", + ) + + household_df = MicroDataFrame( + pd.DataFrame( + { + "household_id": [1, 2], + 
"household_weight": [1.0, 1.0], + } + ), + weights="household_weight", + ) + + with tempfile.TemporaryDirectory() as tmpdir: + filepath = os.path.join(tmpdir, "test.h5") + + dataset = PolicyEngineUKDataset( + name="Test", + description="Test dataset", + filepath=filepath, + year=2024, + data=UKYearData( + person=person_df, benunit=benunit_df, household=household_df + ), + ) + + simulation = Simulation( + dataset=dataset, + tax_benefit_model_version=uk_latest, + output_dataset=dataset, + ) + + # Aggregate person-level income at household level + agg = Aggregate( + simulation=simulation, + variable="employment_income", + aggregate_type=AggregateType.SUM, + entity="household", + ) + agg.run() + + # Should sum across all people mapped to households + assert agg.result == 140000 + + +def test_aggregate_with_filter(): + """Test aggregation with basic filter.""" + person_df = MicroDataFrame( + pd.DataFrame( + { + "person_id": [1, 2, 3, 4], + "benunit_id": [1, 1, 2, 2], + "household_id": [1, 1, 2, 2], + "age": [30, 25, 40, 35], + "employment_income": [50000, 30000, 60000, 45000], + "person_weight": [1.0, 1.0, 1.0, 1.0], + } + ), + weights="person_weight", + ) + + benunit_df = MicroDataFrame( + pd.DataFrame( + { + "benunit_id": [1, 2], + "benunit_weight": [1.0, 1.0], + } + ), + weights="benunit_weight", + ) + + household_df = MicroDataFrame( + pd.DataFrame( + { + "household_id": [1, 2], + "household_weight": [1.0, 1.0], + } + ), + weights="household_weight", + ) + + with tempfile.TemporaryDirectory() as tmpdir: + filepath = os.path.join(tmpdir, "test.h5") + + dataset = PolicyEngineUKDataset( + name="Test", + description="Test dataset", + filepath=filepath, + year=2024, + data=UKYearData( + person=person_df, benunit=benunit_df, household=household_df + ), + ) + + simulation = Simulation( + dataset=dataset, + tax_benefit_model_version=uk_latest, + output_dataset=dataset, + ) + + # Sum income for people age >= 30 + agg = Aggregate( + simulation=simulation, + variable="employment_income", + aggregate_type=AggregateType.SUM, + filter_variable="age", + filter_variable_geq=30, + ) + agg.run() + + # Should only include people aged 30, 40, and 35 + assert agg.result == 50000 + 60000 + 45000 + + +def test_aggregate_with_quantile_filter(): + """Test aggregation with quantile-based filter.""" + person_df = MicroDataFrame( + pd.DataFrame( + { + "person_id": [1, 2, 3, 4, 5], + "benunit_id": [1, 1, 2, 2, 3], + "household_id": [1, 1, 2, 2, 3], + "age": [20, 30, 40, 50, 60], + "employment_income": [10000, 20000, 30000, 40000, 50000], + "person_weight": [1.0, 1.0, 1.0, 1.0, 1.0], + } + ), + weights="person_weight", + ) + + benunit_df = MicroDataFrame( + pd.DataFrame( + { + "benunit_id": [1, 2, 3], + "benunit_weight": [1.0, 1.0, 1.0], + } + ), + weights="benunit_weight", + ) + + household_df = MicroDataFrame( + pd.DataFrame( + { + "household_id": [1, 2, 3], + "household_weight": [1.0, 1.0, 1.0], + } + ), + weights="household_weight", + ) + + with tempfile.TemporaryDirectory() as tmpdir: + filepath = os.path.join(tmpdir, "test.h5") + + dataset = PolicyEngineUKDataset( + name="Test", + description="Test dataset", + filepath=filepath, + year=2024, + data=UKYearData( + person=person_df, benunit=benunit_df, household=household_df + ), + ) + + simulation = Simulation( + dataset=dataset, + tax_benefit_model_version=uk_latest, + output_dataset=dataset, + ) + + # Sum income for bottom 50% (by income) + agg = Aggregate( + simulation=simulation, + variable="employment_income", + aggregate_type=AggregateType.SUM, + 
filter_variable="employment_income", + filter_variable_leq=0.5, + filter_variable_describes_quantiles=True, + ) + agg.run() + + # Should include people with income <= median (30000) + assert agg.result == 10000 + 20000 + 30000 + + +def test_aggregate_invalid_variable(): + """Test that invalid variable names raise errors during run().""" + import pytest + + person_df = MicroDataFrame( + pd.DataFrame( + { + "person_id": [1], + "benunit_id": [1], + "household_id": [1], + "age": [30], + "employment_income": [50000], + "person_weight": [1.0], + } + ), + weights="person_weight", + ) + + benunit_df = MicroDataFrame( + pd.DataFrame( + { + "benunit_id": [1], + "benunit_weight": [1.0], + } + ), + weights="benunit_weight", + ) + + household_df = MicroDataFrame( + pd.DataFrame( + { + "household_id": [1], + "household_weight": [1.0], + } + ), + weights="household_weight", + ) + + with tempfile.TemporaryDirectory() as tmpdir: + filepath = os.path.join(tmpdir, "test.h5") + + dataset = PolicyEngineUKDataset( + name="Test", + description="Test dataset", + filepath=filepath, + year=2024, + data=UKYearData( + person=person_df, benunit=benunit_df, household=household_df + ), + ) + + simulation = Simulation( + dataset=dataset, + tax_benefit_model_version=uk_latest, + output_dataset=dataset, + ) + + # Invalid variable name should raise error on run() + agg = Aggregate( + simulation=simulation, + variable="nonexistent_variable", + aggregate_type=AggregateType.SUM, + ) + with pytest.raises(StopIteration): + agg.run() + + # Invalid filter variable name should raise error on run() + agg = Aggregate( + simulation=simulation, + variable="employment_income", + aggregate_type=AggregateType.SUM, + filter_variable="nonexistent_filter", + ) + with pytest.raises(StopIteration): + agg.run() diff --git a/tests/test_entity_mapping.py b/tests/test_entity_mapping.py new file mode 100644 index 00000000..ffbecacd --- /dev/null +++ b/tests/test_entity_mapping.py @@ -0,0 +1,384 @@ +import pandas as pd +import pytest +from microdf import MicroDataFrame +from policyengine.tax_benefit_models.uk import UKYearData + + +def test_map_same_entity(): + """Test mapping from an entity to itself returns the same data.""" + person_df = MicroDataFrame( + pd.DataFrame( + { + "person_id": [1, 2, 3], + "benunit_id": [1, 1, 2], + "household_id": [1, 1, 2], + "age": [30, 25, 40], + "person_weight": [1.0, 1.0, 1.0], + } + ), + weights="person_weight", + ) + + benunit_df = MicroDataFrame( + pd.DataFrame({"benunit_id": [1, 2], "benunit_weight": [1.0, 1.0]}), + weights="benunit_weight", + ) + + household_df = MicroDataFrame( + pd.DataFrame({"household_id": [1, 2], "household_weight": [1.0, 1.0]}), + weights="household_weight", + ) + + data = UKYearData( + person=person_df, benunit=benunit_df, household=household_df + ) + + # Test person -> person + result = data.map_to_entity("person", "person") + assert isinstance(result, MicroDataFrame) + assert len(result) == 3 + assert list(result["person_id"]) == [1, 2, 3] + + # Test benunit -> benunit + result = data.map_to_entity("benunit", "benunit") + assert isinstance(result, MicroDataFrame) + assert len(result) == 2 + assert list(result["benunit_id"]) == [1, 2] + + # Test household -> household + result = data.map_to_entity("household", "household") + assert isinstance(result, MicroDataFrame) + assert len(result) == 2 + assert list(result["household_id"]) == [1, 2] + + +def test_map_person_to_benunit(): + """Test mapping person-level data to benunit level.""" + person_df = MicroDataFrame( + pd.DataFrame( + { + 
"person_id": [1, 2, 3], + "benunit_id": [1, 1, 2], + "household_id": [1, 1, 2], + "age": [30, 25, 40], + "income": [50000, 30000, 60000], + "person_weight": [1.0, 1.0, 1.0], + } + ), + weights="person_weight", + ) + + benunit_df = MicroDataFrame( + pd.DataFrame({"benunit_id": [1, 2], "benunit_weight": [1.0, 1.0]}), + weights="benunit_weight", + ) + + household_df = MicroDataFrame( + pd.DataFrame({"household_id": [1, 2], "household_weight": [1.0, 1.0]}), + weights="household_weight", + ) + + data = UKYearData( + person=person_df, benunit=benunit_df, household=household_df + ) + + result = data.map_to_entity("person", "benunit") + + # Should return a MicroDataFrame + assert isinstance(result, MicroDataFrame) + # Should have rows for each person + assert len(result) == 3 + # Should have benunit data merged in + assert "benunit_id" in result.columns + assert "person_id" in result.columns + + +def test_map_person_to_household(): + """Test mapping person-level data to household level.""" + person_df = MicroDataFrame( + pd.DataFrame( + { + "person_id": [1, 2, 3], + "benunit_id": [1, 1, 2], + "household_id": [1, 1, 2], + "age": [30, 25, 40], + "person_weight": [1.0, 1.0, 1.0], + } + ), + weights="person_weight", + ) + + benunit_df = MicroDataFrame( + pd.DataFrame({"benunit_id": [1, 2], "benunit_weight": [1.0, 1.0]}), + weights="benunit_weight", + ) + + household_df = MicroDataFrame( + pd.DataFrame( + { + "household_id": [1, 2], + "rent": [1000, 800], + "household_weight": [1.0, 1.0], + } + ), + weights="household_weight", + ) + + data = UKYearData( + person=person_df, benunit=benunit_df, household=household_df + ) + + result = data.map_to_entity("person", "household") + + # Should have rows for each person + assert len(result) == 3 + # Should have household data merged in + assert "household_id" in result.columns + assert "person_id" in result.columns + assert "rent" in result.columns + + +def test_map_benunit_to_person(): + """Test mapping benunit-level data to person level.""" + person_df = MicroDataFrame( + pd.DataFrame( + { + "person_id": [1, 2, 3], + "benunit_id": [1, 1, 2], + "household_id": [1, 1, 2], + "person_weight": [1.0, 1.0, 1.0], + } + ), + weights="person_weight", + ) + + benunit_df = MicroDataFrame( + pd.DataFrame( + { + "benunit_id": [1, 2], + "total_benefit": [1000, 500], + "benunit_weight": [1.0, 1.0], + } + ), + weights="benunit_weight", + ) + + household_df = MicroDataFrame( + pd.DataFrame({"household_id": [1, 2], "household_weight": [1.0, 1.0]}), + weights="household_weight", + ) + + data = UKYearData( + person=person_df, benunit=benunit_df, household=household_df + ) + + result = data.map_to_entity("benunit", "person") + + # Should have rows for each person + assert len(result) == 3 + # Should have benunit data merged in + assert "benunit_id" in result.columns + assert "person_id" in result.columns + assert "total_benefit" in result.columns + + +def test_map_benunit_to_household(): + """Test mapping benunit-level data to household level.""" + person_df = MicroDataFrame( + pd.DataFrame( + { + "person_id": [1, 2, 3, 4], + "benunit_id": [1, 1, 2, 2], + "household_id": [1, 1, 2, 2], + "person_weight": [1.0, 1.0, 1.0, 1.0], + } + ), + weights="person_weight", + ) + + benunit_df = MicroDataFrame( + pd.DataFrame( + { + "benunit_id": [1, 2], + "total_benefit": [1000, 500], + "benunit_weight": [1.0, 1.0], + } + ), + weights="benunit_weight", + ) + + household_df = MicroDataFrame( + pd.DataFrame({"household_id": [1, 2], "household_weight": [1.0, 1.0]}), + weights="household_weight", + 
) + + data = UKYearData( + person=person_df, benunit=benunit_df, household=household_df + ) + + result = data.map_to_entity("benunit", "household") + + # Should have benunit and household data + assert "benunit_id" in result.columns + assert "household_id" in result.columns + assert "total_benefit" in result.columns + + +def test_map_household_to_person(): + """Test mapping household-level data to person level.""" + person_df = MicroDataFrame( + pd.DataFrame( + { + "person_id": [1, 2, 3], + "benunit_id": [1, 1, 2], + "household_id": [1, 1, 2], + "person_weight": [1.0, 1.0, 1.0], + } + ), + weights="person_weight", + ) + + benunit_df = MicroDataFrame( + pd.DataFrame({"benunit_id": [1, 2], "benunit_weight": [1.0, 1.0]}), + weights="benunit_weight", + ) + + household_df = MicroDataFrame( + pd.DataFrame( + { + "household_id": [1, 2], + "rent": [1000, 800], + "household_weight": [1.0, 1.0], + } + ), + weights="household_weight", + ) + + data = UKYearData( + person=person_df, benunit=benunit_df, household=household_df + ) + + result = data.map_to_entity("household", "person") + + # Should have rows for each person + assert len(result) == 3 + # Should have household data merged in + assert "household_id" in result.columns + assert "person_id" in result.columns + assert "rent" in result.columns + + +def test_map_household_to_benunit(): + """Test mapping household-level data to benunit level.""" + person_df = MicroDataFrame( + pd.DataFrame( + { + "person_id": [1, 2, 3, 4], + "benunit_id": [1, 1, 2, 2], + "household_id": [1, 1, 2, 2], + "person_weight": [1.0, 1.0, 1.0, 1.0], + } + ), + weights="person_weight", + ) + + benunit_df = MicroDataFrame( + pd.DataFrame({"benunit_id": [1, 2], "benunit_weight": [1.0, 1.0]}), + weights="benunit_weight", + ) + + household_df = MicroDataFrame( + pd.DataFrame( + { + "household_id": [1, 2], + "rent": [1000, 800], + "household_weight": [1.0, 1.0], + } + ), + weights="household_weight", + ) + + data = UKYearData( + person=person_df, benunit=benunit_df, household=household_df + ) + + result = data.map_to_entity("household", "benunit") + + # Should have benunit and household data + assert "benunit_id" in result.columns + assert "household_id" in result.columns + assert "rent" in result.columns + + +def test_map_with_column_selection(): + """Test mapping with specific column selection.""" + person_df = MicroDataFrame( + pd.DataFrame( + { + "person_id": [1, 2, 3], + "benunit_id": [1, 1, 2], + "household_id": [1, 1, 2], + "age": [30, 25, 40], + "income": [50000, 30000, 60000], + "person_weight": [1.0, 1.0, 1.0], + } + ), + weights="person_weight", + ) + + benunit_df = MicroDataFrame( + pd.DataFrame({"benunit_id": [1, 2], "benunit_weight": [1.0, 1.0]}), + weights="benunit_weight", + ) + + household_df = MicroDataFrame( + pd.DataFrame({"household_id": [1, 2], "household_weight": [1.0, 1.0]}), + weights="household_weight", + ) + + data = UKYearData( + person=person_df, benunit=benunit_df, household=household_df + ) + + # Map only age to household + result = data.map_to_entity("person", "household", columns=["age"]) + + assert "age" in result.columns + assert "household_id" in result.columns + # income should not be included + assert "income" not in result.columns + + +def test_invalid_entity_names(): + """Test that invalid entity names raise ValueError.""" + person_df = MicroDataFrame( + pd.DataFrame( + { + "person_id": [1], + "benunit_id": [1], + "household_id": [1], + "person_weight": [1.0], + } + ), + weights="person_weight", + ) + + benunit_df = MicroDataFrame( + 
pd.DataFrame({"benunit_id": [1], "benunit_weight": [1.0]}), + weights="benunit_weight", + ) + + household_df = MicroDataFrame( + pd.DataFrame({"household_id": [1], "household_weight": [1.0]}), + weights="household_weight", + ) + + data = UKYearData( + person=person_df, benunit=benunit_df, household=household_df + ) + + with pytest.raises(ValueError, match="Invalid source entity"): + data.map_to_entity("invalid", "person") + + with pytest.raises(ValueError, match="Invalid target entity"): + data.map_to_entity("person", "invalid") diff --git a/tests/test_uk_dataset.py b/tests/test_uk_dataset.py index d4e2ed83..5e3f8e68 100644 --- a/tests/test_uk_dataset.py +++ b/tests/test_uk_dataset.py @@ -1,6 +1,7 @@ import pandas as pd import tempfile import os +from microdf import MicroDataFrame from policyengine.core import * from policyengine.tax_benefit_models.uk import ( PolicyEngineUKDataset, @@ -33,20 +34,40 @@ def test_uk_latest_instantiation(): def test_save_and_load_single_year(): """Test saving and loading a dataset with a single year.""" # Create sample data - person_df = pd.DataFrame( - { - "person_id": [1, 2, 3], - "age": [25, 30, 35], - "income": [30000, 45000, 60000], - } + person_df = MicroDataFrame( + pd.DataFrame( + { + "person_id": [1, 2, 3], + "age": [25, 30, 35], + "income": [30000, 45000, 60000], + "person_weight": [1.0, 1.0, 1.0], + } + ), + weights="person_weight", ) - benunit_df = pd.DataFrame( - {"benunit_id": [1, 2], "size": [2, 1], "total_income": [75000, 60000]} + benunit_df = MicroDataFrame( + pd.DataFrame( + { + "benunit_id": [1, 2], + "size": [2, 1], + "total_income": [75000, 60000], + "benunit_weight": [1.0, 1.0], + } + ), + weights="benunit_weight", ) - household_df = pd.DataFrame( - {"household_id": [1], "num_people": [3], "rent": [1200]} + household_df = MicroDataFrame( + pd.DataFrame( + { + "household_id": [1], + "num_people": [3], + "rent": [1200], + "household_weight": [1.0], + } + ), + weights="household_weight", ) # Create dataset @@ -77,6 +98,13 @@ def test_save_and_load_single_year(): # Verify data assert loaded.year == 2025 - pd.testing.assert_frame_equal(loaded.data.person, person_df) - pd.testing.assert_frame_equal(loaded.data.benunit, benunit_df) - pd.testing.assert_frame_equal(loaded.data.household, household_df) + # Convert to DataFrame for comparison (MicroDataFrame inherits from DataFrame) + pd.testing.assert_frame_equal( + pd.DataFrame(loaded.data.person), pd.DataFrame(person_df) + ) + pd.testing.assert_frame_equal( + pd.DataFrame(loaded.data.benunit), pd.DataFrame(benunit_df) + ) + pd.testing.assert_frame_equal( + pd.DataFrame(loaded.data.household), pd.DataFrame(household_df) + ) From 48132068faa9641301bd24f5cb512fbdee6e519f Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Wed, 12 Nov 2025 18:02:02 +0000 Subject: [PATCH 19/35] Add change-aggregate --- examples/policy_change.py | 313 ++++++++ src/policyengine/outputs/__init__.py | 9 + src/policyengine/outputs/change_aggregate.py | 125 ++++ src/policyengine/tax_benefit_models/uk.py | 2 - tests/test_change_aggregate.py | 721 +++++++++++++++++++ 5 files changed, 1168 insertions(+), 2 deletions(-) create mode 100644 examples/policy_change.py create mode 100644 src/policyengine/outputs/__init__.py create mode 100644 src/policyengine/outputs/change_aggregate.py create mode 100644 tests/test_change_aggregate.py diff --git a/examples/policy_change.py b/examples/policy_change.py new file mode 100644 index 00000000..527c0bbe --- /dev/null +++ b/examples/policy_change.py @@ -0,0 +1,313 @@ +"""Example: Analyze 
policy change impacts using ChangeAggregate. + +This script demonstrates: +1. Creating baseline and reform datasets with different income distributions +2. Using ChangeAggregate to analyze winners, losers, and impact sizes +3. Filtering by absolute and relative change thresholds +4. Visualising results with Plotly + +Run: python examples/policy_change.py +""" +import numpy as np +import pandas as pd +from microdf import MicroDataFrame +import plotly.graph_objects as go +from plotly.subplots import make_subplots +from policyengine.core import Simulation +from policyengine.tax_benefit_models.uk import ( + PolicyEngineUKDataset, + UKYearData, + uk_latest, +) +from policyengine.outputs.change_aggregate import ChangeAggregate, ChangeAggregateType + +# Create baseline dataset with random incomes +np.random.seed(42) +n_people = 1000 + +baseline_person_df = MicroDataFrame( + pd.DataFrame({ + "person_id": range(1, n_people + 1), + "benunit_id": range(1, n_people + 1), + "household_id": range(1, n_people + 1), + "age": np.random.randint(18, 70, n_people), + "employment_income": np.random.exponential(35000, n_people), + "person_weight": np.ones(n_people), + }), + weights="person_weight" +) + +benunit_df = MicroDataFrame( + pd.DataFrame({ + "benunit_id": range(1, n_people + 1), + "benunit_weight": np.ones(n_people), + }), + weights="benunit_weight" +) + +household_df = MicroDataFrame( + pd.DataFrame({ + "household_id": range(1, n_people + 1), + "household_weight": np.ones(n_people), + }), + weights="household_weight" +) + +# Create baseline dataset +baseline_dataset = PolicyEngineUKDataset( + name="Baseline", + description="Baseline scenario", + filepath="./baseline_data.h5", + year=2024, + data=UKYearData(person=baseline_person_df, benunit=benunit_df, household=household_df), +) + +# Create reform dataset - progressive income boost +# Low earners get 10% boost, high earners get 5% boost, middle gets 7.5% +baseline_incomes = baseline_person_df["employment_income"].values +reform_incomes = [] +for income in baseline_incomes: + if income < 25000: + boost = 0.10 # 10% for low earners + elif income < 50000: + boost = 0.075 # 7.5% for middle earners + else: + boost = 0.05 # 5% for high earners + reform_incomes.append(income * (1 + boost)) + +reform_person_df = MicroDataFrame( + pd.DataFrame({ + "person_id": range(1, n_people + 1), + "benunit_id": range(1, n_people + 1), + "household_id": range(1, n_people + 1), + "age": baseline_person_df["age"].values, + "employment_income": reform_incomes, + "person_weight": np.ones(n_people), + }), + weights="person_weight" +) + +reform_dataset = PolicyEngineUKDataset( + name="Reform", + description="Progressive income boost", + filepath="./reform_data.h5", + year=2024, + data=UKYearData(person=reform_person_df, benunit=benunit_df, household=household_df), +) + +# Create simulations +baseline_sim = Simulation( + dataset=baseline_dataset, + tax_benefit_model_version=uk_latest, + output_dataset=baseline_dataset, +) + +reform_sim = Simulation( + dataset=reform_dataset, + tax_benefit_model_version=uk_latest, + output_dataset=reform_dataset, +) + +# Analysis 1: Overall winners/losers +winners = ChangeAggregate( + baseline_simulation=baseline_sim, + reform_simulation=reform_sim, + variable="employment_income", + aggregate_type=ChangeAggregateType.COUNT, + change_geq=1, +) +winners.run() + +losers = ChangeAggregate( + baseline_simulation=baseline_sim, + reform_simulation=reform_sim, + variable="employment_income", + aggregate_type=ChangeAggregateType.COUNT, + change_leq=-1, +) 
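+# Thresholds are in the variable's own units: change_geq=1 above counts
+# anyone gaining at least £1, change_leq=-1 counts anyone losing at least
+# £1, and people whose income is exactly unchanged fall into neither group.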
+losers.run() + +no_change = ChangeAggregate( + baseline_simulation=baseline_sim, + reform_simulation=reform_sim, + variable="employment_income", + aggregate_type=ChangeAggregateType.COUNT, + change_eq=0, +) +no_change.run() + +# Analysis 2: Total gains and losses +total_gain = ChangeAggregate( + baseline_simulation=baseline_sim, + reform_simulation=reform_sim, + variable="employment_income", + aggregate_type=ChangeAggregateType.SUM, + change_geq=0, +) +total_gain.run() + +total_loss = ChangeAggregate( + baseline_simulation=baseline_sim, + reform_simulation=reform_sim, + variable="employment_income", + aggregate_type=ChangeAggregateType.SUM, + change_leq=0, +) +total_loss.run() + +# Analysis 3: Distribution of gains by size +gain_bands = [ + ("£0-500", 0, 500), + ("£500-1k", 500, 1000), + ("£1k-2k", 1000, 2000), + ("£2k-3k", 2000, 3000), + ("£3k-5k", 3000, 5000), + ("£5k+", 5000, None), +] + +gain_counts = [] +for label, lower, upper in gain_bands: + agg = ChangeAggregate( + baseline_simulation=baseline_sim, + reform_simulation=reform_sim, + variable="employment_income", + aggregate_type=ChangeAggregateType.COUNT, + change_geq=lower, + change_leq=upper, + ) + agg.run() + gain_counts.append(agg.result) + +# Analysis 4: Impact by age group +age_groups = [ + ("18-30", 18, 30), + ("31-45", 31, 45), + ("46-60", 46, 60), + ("61+", 61, 150), +] + +age_group_labels = [] +age_group_winners = [] +age_group_avg_gain = [] + +for label, min_age, max_age in age_groups: + # Count winners in age group + count_agg = ChangeAggregate( + baseline_simulation=baseline_sim, + reform_simulation=reform_sim, + variable="employment_income", + aggregate_type=ChangeAggregateType.COUNT, + change_geq=1, + filter_variable="age", + filter_variable_geq=min_age, + filter_variable_leq=max_age, + ) + count_agg.run() + + # Average gain in age group + mean_agg = ChangeAggregate( + baseline_simulation=baseline_sim, + reform_simulation=reform_sim, + variable="employment_income", + aggregate_type=ChangeAggregateType.MEAN, + filter_variable="age", + filter_variable_geq=min_age, + filter_variable_leq=max_age, + ) + mean_agg.run() + + age_group_labels.append(label) + age_group_winners.append(count_agg.result) + age_group_avg_gain.append(mean_agg.result) + +# Analysis 5: Large winners (gaining more than 10%) +large_winners = ChangeAggregate( + baseline_simulation=baseline_sim, + reform_simulation=reform_sim, + variable="employment_income", + aggregate_type=ChangeAggregateType.COUNT, + relative_change_geq=0.10, +) +large_winners.run() + +# Create visualisations +fig = make_subplots( + rows=2, cols=2, + subplot_titles=( + "Winners vs losers", + "Distribution of gains", + "Winners by age group", + "Average gain by age group" + ), + specs=[[{"type": "bar"}, {"type": "bar"}], + [{"type": "bar"}, {"type": "bar"}]] +) + +fig.add_trace( + go.Bar( + x=["Winners", "No change", "Losers"], + y=[winners.result, no_change.result, losers.result], + marker_color=["green", "gray", "red"] + ), + row=1, col=1 +) + +fig.add_trace( + go.Bar( + x=[label for label, _, _ in gain_bands], + y=gain_counts, + marker_color="lightblue" + ), + row=1, col=2 +) + +fig.add_trace( + go.Bar( + x=age_group_labels, + y=age_group_winners, + marker_color="lightgreen" + ), + row=2, col=1 +) + +fig.add_trace( + go.Bar( + x=age_group_labels, + y=age_group_avg_gain, + marker_color="orange" + ), + row=2, col=2 +) + +fig.update_xaxes(title_text="Category", row=1, col=1) +fig.update_xaxes(title_text="Gain amount", row=1, col=2) +fig.update_xaxes(title_text="Age group", row=2, 
col=1) +fig.update_xaxes(title_text="Age group", row=2, col=2) + +fig.update_yaxes(title_text="Number of people", row=1, col=1) +fig.update_yaxes(title_text="Number of people", row=1, col=2) +fig.update_yaxes(title_text="Number of winners", row=2, col=1) +fig.update_yaxes(title_text="Average gain (£)", row=2, col=2) + +fig.update_layout( + title_text="Policy change impact analysis", + showlegend=False, + height=800, +) + +# Print summary statistics +print("=" * 60) +print("Policy change impact summary") +print("=" * 60) +print(f"\nOverall impact:") +print(f" Winners: {winners.result:,.0f} people") +print(f" Losers: {losers.result:,.0f} people") +print(f" No change: {no_change.result:,.0f} people") +print(f"\nFinancial impact:") +print(f" Total gains: £{total_gain.result:,.0f}") +print(f" Total losses: £{total_loss.result:,.0f}") +print(f" Net change: £{total_gain.result + total_loss.result:,.0f}") +print(f"\nLarge winners (>10% gain): {large_winners.result:,.0f} people") +print("=" * 60) + +fig.show() diff --git a/src/policyengine/outputs/__init__.py b/src/policyengine/outputs/__init__.py new file mode 100644 index 00000000..8b255697 --- /dev/null +++ b/src/policyengine/outputs/__init__.py @@ -0,0 +1,9 @@ +from policyengine.outputs.aggregate import Aggregate, AggregateType +from policyengine.outputs.change_aggregate import ChangeAggregate, ChangeAggregateType + +__all__ = [ + "Aggregate", + "AggregateType", + "ChangeAggregate", + "ChangeAggregateType", +] diff --git a/src/policyengine/outputs/change_aggregate.py b/src/policyengine/outputs/change_aggregate.py new file mode 100644 index 00000000..57a01479 --- /dev/null +++ b/src/policyengine/outputs/change_aggregate.py @@ -0,0 +1,125 @@ +from pydantic import BaseModel +from policyengine.core import * +from enum import Enum +from typing import Any + + +class ChangeAggregateType(str, Enum): + COUNT = "count" + SUM = "sum" + MEAN = "mean" + + +class ChangeAggregate(BaseModel): + baseline_simulation: Simulation + reform_simulation: Simulation + variable: str + aggregate_type: ChangeAggregateType + entity: str | None = None + + # Filter by absolute change + change_geq: float | None = None # Change >= value (e.g., gain >= 500) + change_leq: float | None = None # Change <= value (e.g., loss <= -500) + change_eq: float | None = None # Change == value + + # Filter by relative change (as decimal, e.g., 0.05 = 5%) + relative_change_geq: float | None = None # Relative change >= value + relative_change_leq: float | None = None # Relative change <= value + relative_change_eq: float | None = None # Relative change == value + + # Filter by another variable (e.g., only count people with age >= 30) + filter_variable: str | None = None + filter_variable_eq: Any | None = None + filter_variable_leq: Any | None = None + filter_variable_geq: Any | None = None + + result: Any | None = None + + def run(self): + # Get variable object + var_obj = next( + v for v in self.baseline_simulation.tax_benefit_model_version.variables + if v.name == self.variable + ) + + # Get the target entity data + target_entity = self.entity or var_obj.entity + baseline_data = getattr(self.baseline_simulation.output_dataset.data, target_entity) + reform_data = getattr(self.reform_simulation.output_dataset.data, target_entity) + + # Map variable to target entity if needed + if var_obj.entity != target_entity: + baseline_mapped = self.baseline_simulation.output_dataset.data.map_to_entity( + var_obj.entity, target_entity + ) + baseline_series = baseline_mapped[self.variable] + + reform_mapped 
= self.reform_simulation.output_dataset.data.map_to_entity( + var_obj.entity, target_entity + ) + reform_series = reform_mapped[self.variable] + else: + baseline_series = baseline_data[self.variable] + reform_series = reform_data[self.variable] + + # Calculate change (reform - baseline) + change_series = reform_series - baseline_series + + # Calculate relative change (handling division by zero) + # Where baseline is 0, relative change is undefined; we'll mask these out if relative filters are used + import numpy as np + with np.errstate(divide='ignore', invalid='ignore'): + relative_change_series = change_series / baseline_series + relative_change_series = relative_change_series.replace([np.inf, -np.inf], np.nan) + + # Start with all rows + mask = baseline_series.notna() + + # Apply absolute change filters + if self.change_eq is not None: + mask &= (change_series == self.change_eq) + if self.change_leq is not None: + mask &= (change_series <= self.change_leq) + if self.change_geq is not None: + mask &= (change_series >= self.change_geq) + + # Apply relative change filters + if self.relative_change_eq is not None: + mask &= (relative_change_series == self.relative_change_eq) + if self.relative_change_leq is not None: + mask &= (relative_change_series <= self.relative_change_leq) + if self.relative_change_geq is not None: + mask &= (relative_change_series >= self.relative_change_geq) + + # Apply filter_variable filters + if self.filter_variable is not None: + filter_var_obj = next( + v for v in self.baseline_simulation.tax_benefit_model_version.variables + if v.name == self.filter_variable + ) + + if filter_var_obj.entity != target_entity: + filter_mapped = self.baseline_simulation.output_dataset.data.map_to_entity( + filter_var_obj.entity, target_entity + ) + filter_series = filter_mapped[self.filter_variable] + else: + filter_series = baseline_data[self.filter_variable] + + if self.filter_variable_eq is not None: + mask &= (filter_series == self.filter_variable_eq) + if self.filter_variable_leq is not None: + mask &= (filter_series <= self.filter_variable_leq) + if self.filter_variable_geq is not None: + mask &= (filter_series >= self.filter_variable_geq) + + # Apply mask to get filtered data + filtered_change = change_series[mask] + + # Aggregate + if self.aggregate_type == ChangeAggregateType.COUNT: + self.result = filtered_change.count() + elif self.aggregate_type == ChangeAggregateType.SUM: + self.result = filtered_change.sum() + elif self.aggregate_type == ChangeAggregateType.MEAN: + self.result = filtered_change.mean() diff --git a/src/policyengine/tax_benefit_models/uk.py b/src/policyengine/tax_benefit_models/uk.py index b6e14aaf..1a935b5d 100644 --- a/src/policyengine/tax_benefit_models/uk.py +++ b/src/policyengine/tax_benefit_models/uk.py @@ -335,7 +335,6 @@ def run(self, simulation: "Simulation") -> "Simulation": # Tax "income_tax", "national_insurance", - "net_income", ], "benunit": [ # IDs and weights @@ -369,7 +368,6 @@ def run(self, simulation: "Simulation") -> "Simulation": # Housing "rent", "council_tax", - "housing_benefit", ], } diff --git a/tests/test_change_aggregate.py b/tests/test_change_aggregate.py new file mode 100644 index 00000000..de878996 --- /dev/null +++ b/tests/test_change_aggregate.py @@ -0,0 +1,721 @@ +import pandas as pd +import tempfile +import os +from microdf import MicroDataFrame +from policyengine.core import * +from policyengine.tax_benefit_models.uk import ( + PolicyEngineUKDataset, + UKYearData, + uk_latest, +) +from 
policyengine.outputs.change_aggregate import ChangeAggregate, ChangeAggregateType + + +def test_change_aggregate_count(): + """Test counting people with any change.""" + person_df = MicroDataFrame( + pd.DataFrame({ + "person_id": [1, 2, 3, 4], + "benunit_id": [1, 1, 2, 2], + "household_id": [1, 1, 2, 2], + "age": [30, 25, 40, 35], + "employment_income": [50000, 30000, 60000, 40000], + "person_weight": [1.0, 1.0, 1.0, 1.0], + }), + weights="person_weight" + ) + + benunit_df = MicroDataFrame( + pd.DataFrame({ + "benunit_id": [1, 2], + "benunit_weight": [1.0, 1.0], + }), + weights="benunit_weight" + ) + + household_df = MicroDataFrame( + pd.DataFrame({ + "household_id": [1, 2], + "household_weight": [1.0, 1.0], + }), + weights="household_weight" + ) + + with tempfile.TemporaryDirectory() as tmpdir: + baseline_filepath = os.path.join(tmpdir, "baseline.h5") + reform_filepath = os.path.join(tmpdir, "reform.h5") + + baseline_dataset = PolicyEngineUKDataset( + name="Baseline", + description="Baseline dataset", + filepath=baseline_filepath, + year=2024, + data=UKYearData(person=person_df, benunit=benunit_df, household=household_df), + ) + + # Reform: increase everyone's income by 1000 + reform_person_df = MicroDataFrame( + pd.DataFrame({ + "person_id": [1, 2, 3, 4], + "benunit_id": [1, 1, 2, 2], + "household_id": [1, 1, 2, 2], + "age": [30, 25, 40, 35], + "employment_income": [51000, 31000, 61000, 41000], + "person_weight": [1.0, 1.0, 1.0, 1.0], + }), + weights="person_weight" + ) + + reform_dataset = PolicyEngineUKDataset( + name="Reform", + description="Reform dataset", + filepath=reform_filepath, + year=2024, + data=UKYearData(person=reform_person_df, benunit=benunit_df, household=household_df), + ) + + baseline_sim = Simulation( + dataset=baseline_dataset, + tax_benefit_model_version=uk_latest, + output_dataset=baseline_dataset, + ) + + reform_sim = Simulation( + dataset=reform_dataset, + tax_benefit_model_version=uk_latest, + output_dataset=reform_dataset, + ) + + # Count people with any change (all 4 should have changed) + agg = ChangeAggregate( + baseline_simulation=baseline_sim, + reform_simulation=reform_sim, + variable="employment_income", + aggregate_type=ChangeAggregateType.COUNT, + ) + agg.run() + + assert agg.result == 4 + + +def test_change_aggregate_with_absolute_filter(): + """Test filtering by absolute change amount.""" + person_df = MicroDataFrame( + pd.DataFrame({ + "person_id": [1, 2, 3, 4], + "benunit_id": [1, 1, 2, 2], + "household_id": [1, 1, 2, 2], + "age": [30, 25, 40, 35], + "employment_income": [50000, 30000, 60000, 40000], + "person_weight": [1.0, 1.0, 1.0, 1.0], + }), + weights="person_weight" + ) + + benunit_df = MicroDataFrame( + pd.DataFrame({ + "benunit_id": [1, 2], + "benunit_weight": [1.0, 1.0], + }), + weights="benunit_weight" + ) + + household_df = MicroDataFrame( + pd.DataFrame({ + "household_id": [1, 2], + "household_weight": [1.0, 1.0], + }), + weights="household_weight" + ) + + with tempfile.TemporaryDirectory() as tmpdir: + baseline_filepath = os.path.join(tmpdir, "baseline.h5") + reform_filepath = os.path.join(tmpdir, "reform.h5") + + baseline_dataset = PolicyEngineUKDataset( + name="Baseline", + description="Baseline dataset", + filepath=baseline_filepath, + year=2024, + data=UKYearData(person=person_df, benunit=benunit_df, household=household_df), + ) + + # Reform: different gains for different people + reform_person_df = MicroDataFrame( + pd.DataFrame({ + "person_id": [1, 2, 3, 4], + "benunit_id": [1, 1, 2, 2], + "household_id": [1, 1, 2, 2], + "age": 
[30, 25, 40, 35], + "employment_income": [52000, 30500, 61500, 40200], # Gains: 2000, 500, 1500, 200 + "person_weight": [1.0, 1.0, 1.0, 1.0], + }), + weights="person_weight" + ) + + reform_dataset = PolicyEngineUKDataset( + name="Reform", + description="Reform dataset", + filepath=reform_filepath, + year=2024, + data=UKYearData(person=reform_person_df, benunit=benunit_df, household=household_df), + ) + + baseline_sim = Simulation( + dataset=baseline_dataset, + tax_benefit_model_version=uk_latest, + output_dataset=baseline_dataset, + ) + + reform_sim = Simulation( + dataset=reform_dataset, + tax_benefit_model_version=uk_latest, + output_dataset=reform_dataset, + ) + + # Count people who gain at least 1000 + agg = ChangeAggregate( + baseline_simulation=baseline_sim, + reform_simulation=reform_sim, + variable="employment_income", + aggregate_type=ChangeAggregateType.COUNT, + change_geq=1000, + ) + agg.run() + + assert agg.result == 2 # People 1 and 3 + + +def test_change_aggregate_with_loss_filter(): + """Test filtering for losses (negative changes).""" + person_df = MicroDataFrame( + pd.DataFrame({ + "person_id": [1, 2, 3, 4], + "benunit_id": [1, 1, 2, 2], + "household_id": [1, 1, 2, 2], + "age": [30, 25, 40, 35], + "employment_income": [50000, 30000, 60000, 40000], + "person_weight": [1.0, 1.0, 1.0, 1.0], + }), + weights="person_weight" + ) + + benunit_df = MicroDataFrame( + pd.DataFrame({ + "benunit_id": [1, 2], + "benunit_weight": [1.0, 1.0], + }), + weights="benunit_weight" + ) + + household_df = MicroDataFrame( + pd.DataFrame({ + "household_id": [1, 2], + "household_weight": [1.0, 1.0], + }), + weights="household_weight" + ) + + with tempfile.TemporaryDirectory() as tmpdir: + baseline_filepath = os.path.join(tmpdir, "baseline.h5") + reform_filepath = os.path.join(tmpdir, "reform.h5") + + baseline_dataset = PolicyEngineUKDataset( + name="Baseline", + description="Baseline dataset", + filepath=baseline_filepath, + year=2024, + data=UKYearData(person=person_df, benunit=benunit_df, household=household_df), + ) + + # Reform: some people lose money + reform_person_df = MicroDataFrame( + pd.DataFrame({ + "person_id": [1, 2, 3, 4], + "benunit_id": [1, 1, 2, 2], + "household_id": [1, 1, 2, 2], + "age": [30, 25, 40, 35], + "employment_income": [49000, 29000, 60500, 39000], # Changes: -1000, -1000, 500, -1000 + "person_weight": [1.0, 1.0, 1.0, 1.0], + }), + weights="person_weight" + ) + + reform_dataset = PolicyEngineUKDataset( + name="Reform", + description="Reform dataset", + filepath=reform_filepath, + year=2024, + data=UKYearData(person=reform_person_df, benunit=benunit_df, household=household_df), + ) + + baseline_sim = Simulation( + dataset=baseline_dataset, + tax_benefit_model_version=uk_latest, + output_dataset=baseline_dataset, + ) + + reform_sim = Simulation( + dataset=reform_dataset, + tax_benefit_model_version=uk_latest, + output_dataset=reform_dataset, + ) + + # Count people who lose at least 500 (change <= -500) + agg = ChangeAggregate( + baseline_simulation=baseline_sim, + reform_simulation=reform_sim, + variable="employment_income", + aggregate_type=ChangeAggregateType.COUNT, + change_leq=-500, + ) + agg.run() + + assert agg.result == 3 # People 1, 2, and 4 + + +def test_change_aggregate_with_relative_filter(): + """Test filtering by relative (percentage) change.""" + person_df = MicroDataFrame( + pd.DataFrame({ + "person_id": [1, 2, 3, 4], + "benunit_id": [1, 1, 2, 2], + "household_id": [1, 1, 2, 2], + "age": [30, 25, 40, 35], + "employment_income": [50000, 20000, 60000, 40000], + 
"person_weight": [1.0, 1.0, 1.0, 1.0], + }), + weights="person_weight" + ) + + benunit_df = MicroDataFrame( + pd.DataFrame({ + "benunit_id": [1, 2], + "benunit_weight": [1.0, 1.0], + }), + weights="benunit_weight" + ) + + household_df = MicroDataFrame( + pd.DataFrame({ + "household_id": [1, 2], + "household_weight": [1.0, 1.0], + }), + weights="household_weight" + ) + + with tempfile.TemporaryDirectory() as tmpdir: + baseline_filepath = os.path.join(tmpdir, "baseline.h5") + reform_filepath = os.path.join(tmpdir, "reform.h5") + + baseline_dataset = PolicyEngineUKDataset( + name="Baseline", + description="Baseline dataset", + filepath=baseline_filepath, + year=2024, + data=UKYearData(person=person_df, benunit=benunit_df, household=household_df), + ) + + # Reform: different percentage gains + reform_person_df = MicroDataFrame( + pd.DataFrame({ + "person_id": [1, 2, 3, 4], + "benunit_id": [1, 1, 2, 2], + "household_id": [1, 1, 2, 2], + "age": [30, 25, 40, 35], + # Gains: 5000 (10%), 2000 (10%), 3000 (5%), 1000 (2.5%) + "employment_income": [55000, 22000, 63000, 41000], + "person_weight": [1.0, 1.0, 1.0, 1.0], + }), + weights="person_weight" + ) + + reform_dataset = PolicyEngineUKDataset( + name="Reform", + description="Reform dataset", + filepath=reform_filepath, + year=2024, + data=UKYearData(person=reform_person_df, benunit=benunit_df, household=household_df), + ) + + baseline_sim = Simulation( + dataset=baseline_dataset, + tax_benefit_model_version=uk_latest, + output_dataset=baseline_dataset, + ) + + reform_sim = Simulation( + dataset=reform_dataset, + tax_benefit_model_version=uk_latest, + output_dataset=reform_dataset, + ) + + # Count people who gain at least 8% (0.08 relative change) + agg = ChangeAggregate( + baseline_simulation=baseline_sim, + reform_simulation=reform_sim, + variable="employment_income", + aggregate_type=ChangeAggregateType.COUNT, + relative_change_geq=0.08, + ) + agg.run() + + assert agg.result == 2 # People 1 and 2 (both 10%) + + +def test_change_aggregate_sum(): + """Test summing changes.""" + person_df = MicroDataFrame( + pd.DataFrame({ + "person_id": [1, 2, 3], + "benunit_id": [1, 1, 2], + "household_id": [1, 1, 2], + "age": [30, 25, 40], + "employment_income": [50000, 30000, 60000], + "person_weight": [1.0, 1.0, 1.0], + }), + weights="person_weight" + ) + + benunit_df = MicroDataFrame( + pd.DataFrame({ + "benunit_id": [1, 2], + "benunit_weight": [1.0, 1.0], + }), + weights="benunit_weight" + ) + + household_df = MicroDataFrame( + pd.DataFrame({ + "household_id": [1, 2], + "household_weight": [1.0, 1.0], + }), + weights="household_weight" + ) + + with tempfile.TemporaryDirectory() as tmpdir: + baseline_filepath = os.path.join(tmpdir, "baseline.h5") + reform_filepath = os.path.join(tmpdir, "reform.h5") + + baseline_dataset = PolicyEngineUKDataset( + name="Baseline", + description="Baseline dataset", + filepath=baseline_filepath, + year=2024, + data=UKYearData(person=person_df, benunit=benunit_df, household=household_df), + ) + + # Reform: everyone gains 1000 + reform_person_df = MicroDataFrame( + pd.DataFrame({ + "person_id": [1, 2, 3], + "benunit_id": [1, 1, 2], + "household_id": [1, 1, 2], + "age": [30, 25, 40], + "employment_income": [51000, 31000, 61000], + "person_weight": [1.0, 1.0, 1.0], + }), + weights="person_weight" + ) + + reform_dataset = PolicyEngineUKDataset( + name="Reform", + description="Reform dataset", + filepath=reform_filepath, + year=2024, + data=UKYearData(person=reform_person_df, benunit=benunit_df, household=household_df), + ) + + 
baseline_sim = Simulation( + dataset=baseline_dataset, + tax_benefit_model_version=uk_latest, + output_dataset=baseline_dataset, + ) + + reform_sim = Simulation( + dataset=reform_dataset, + tax_benefit_model_version=uk_latest, + output_dataset=reform_dataset, + ) + + # Sum all changes + agg = ChangeAggregate( + baseline_simulation=baseline_sim, + reform_simulation=reform_sim, + variable="employment_income", + aggregate_type=ChangeAggregateType.SUM, + ) + agg.run() + + assert agg.result == 3000 + + +def test_change_aggregate_mean(): + """Test mean change.""" + person_df = MicroDataFrame( + pd.DataFrame({ + "person_id": [1, 2, 3], + "benunit_id": [1, 1, 2], + "household_id": [1, 1, 2], + "age": [30, 25, 40], + "employment_income": [50000, 30000, 60000], + "person_weight": [1.0, 1.0, 1.0], + }), + weights="person_weight" + ) + + benunit_df = MicroDataFrame( + pd.DataFrame({ + "benunit_id": [1, 2], + "benunit_weight": [1.0, 1.0], + }), + weights="benunit_weight" + ) + + household_df = MicroDataFrame( + pd.DataFrame({ + "household_id": [1, 2], + "household_weight": [1.0, 1.0], + }), + weights="household_weight" + ) + + with tempfile.TemporaryDirectory() as tmpdir: + baseline_filepath = os.path.join(tmpdir, "baseline.h5") + reform_filepath = os.path.join(tmpdir, "reform.h5") + + baseline_dataset = PolicyEngineUKDataset( + name="Baseline", + description="Baseline dataset", + filepath=baseline_filepath, + year=2024, + data=UKYearData(person=person_df, benunit=benunit_df, household=household_df), + ) + + # Reform: different gains + reform_person_df = MicroDataFrame( + pd.DataFrame({ + "person_id": [1, 2, 3], + "benunit_id": [1, 1, 2], + "household_id": [1, 1, 2], + "age": [30, 25, 40], + "employment_income": [51000, 32000, 63000], # Gains: 1000, 2000, 3000 + "person_weight": [1.0, 1.0, 1.0], + }), + weights="person_weight" + ) + + reform_dataset = PolicyEngineUKDataset( + name="Reform", + description="Reform dataset", + filepath=reform_filepath, + year=2024, + data=UKYearData(person=reform_person_df, benunit=benunit_df, household=household_df), + ) + + baseline_sim = Simulation( + dataset=baseline_dataset, + tax_benefit_model_version=uk_latest, + output_dataset=baseline_dataset, + ) + + reform_sim = Simulation( + dataset=reform_dataset, + tax_benefit_model_version=uk_latest, + output_dataset=reform_dataset, + ) + + # Mean change + agg = ChangeAggregate( + baseline_simulation=baseline_sim, + reform_simulation=reform_sim, + variable="employment_income", + aggregate_type=ChangeAggregateType.MEAN, + ) + agg.run() + + assert agg.result == 2000 + + +def test_change_aggregate_with_filter_variable(): + """Test filtering by another variable (e.g., only adults).""" + person_df = MicroDataFrame( + pd.DataFrame({ + "person_id": [1, 2, 3, 4], + "benunit_id": [1, 1, 2, 2], + "household_id": [1, 1, 2, 2], + "age": [30, 25, 40, 15], # Person 4 is a child + "employment_income": [50000, 30000, 60000, 5000], + "person_weight": [1.0, 1.0, 1.0, 1.0], + }), + weights="person_weight" + ) + + benunit_df = MicroDataFrame( + pd.DataFrame({ + "benunit_id": [1, 2], + "benunit_weight": [1.0, 1.0], + }), + weights="benunit_weight" + ) + + household_df = MicroDataFrame( + pd.DataFrame({ + "household_id": [1, 2], + "household_weight": [1.0, 1.0], + }), + weights="household_weight" + ) + + with tempfile.TemporaryDirectory() as tmpdir: + baseline_filepath = os.path.join(tmpdir, "baseline.h5") + reform_filepath = os.path.join(tmpdir, "reform.h5") + + baseline_dataset = PolicyEngineUKDataset( + name="Baseline", + 
description="Baseline dataset", + filepath=baseline_filepath, + year=2024, + data=UKYearData(person=person_df, benunit=benunit_df, household=household_df), + ) + + # Reform: everyone gains 1000 + reform_person_df = MicroDataFrame( + pd.DataFrame({ + "person_id": [1, 2, 3, 4], + "benunit_id": [1, 1, 2, 2], + "household_id": [1, 1, 2, 2], + "age": [30, 25, 40, 15], + "employment_income": [51000, 31000, 61000, 6000], + "person_weight": [1.0, 1.0, 1.0, 1.0], + }), + weights="person_weight" + ) + + reform_dataset = PolicyEngineUKDataset( + name="Reform", + description="Reform dataset", + filepath=reform_filepath, + year=2024, + data=UKYearData(person=reform_person_df, benunit=benunit_df, household=household_df), + ) + + baseline_sim = Simulation( + dataset=baseline_dataset, + tax_benefit_model_version=uk_latest, + output_dataset=baseline_dataset, + ) + + reform_sim = Simulation( + dataset=reform_dataset, + tax_benefit_model_version=uk_latest, + output_dataset=reform_dataset, + ) + + # Count adults (age >= 18) who gain money + agg = ChangeAggregate( + baseline_simulation=baseline_sim, + reform_simulation=reform_sim, + variable="employment_income", + aggregate_type=ChangeAggregateType.COUNT, + change_geq=1, + filter_variable="age", + filter_variable_geq=18, + ) + agg.run() + + assert agg.result == 3 # Exclude person 4 (age 15) + + +def test_change_aggregate_combined_filters(): + """Test combining multiple filter types.""" + person_df = MicroDataFrame( + pd.DataFrame({ + "person_id": [1, 2, 3, 4, 5], + "benunit_id": [1, 1, 2, 2, 3], + "household_id": [1, 1, 2, 2, 3], + "age": [30, 25, 40, 35, 45], + "employment_income": [50000, 20000, 60000, 40000, 80000], + "person_weight": [1.0, 1.0, 1.0, 1.0, 1.0], + }), + weights="person_weight" + ) + + benunit_df = MicroDataFrame( + pd.DataFrame({ + "benunit_id": [1, 2, 3], + "benunit_weight": [1.0, 1.0, 1.0], + }), + weights="benunit_weight" + ) + + household_df = MicroDataFrame( + pd.DataFrame({ + "household_id": [1, 2, 3], + "household_weight": [1.0, 1.0, 1.0], + }), + weights="household_weight" + ) + + with tempfile.TemporaryDirectory() as tmpdir: + baseline_filepath = os.path.join(tmpdir, "baseline.h5") + reform_filepath = os.path.join(tmpdir, "reform.h5") + + baseline_dataset = PolicyEngineUKDataset( + name="Baseline", + description="Baseline dataset", + filepath=baseline_filepath, + year=2024, + data=UKYearData(person=person_df, benunit=benunit_df, household=household_df), + ) + + # Reform: varying gains + reform_person_df = MicroDataFrame( + pd.DataFrame({ + "person_id": [1, 2, 3, 4, 5], + "benunit_id": [1, 1, 2, 2, 3], + "household_id": [1, 1, 2, 2, 3], + "age": [30, 25, 40, 35, 45], + # Changes: 10000 (20%), 2000 (10%), 3000 (5%), 800 (2%), 4000 (5%) + "employment_income": [60000, 22000, 63000, 40800, 84000], + "person_weight": [1.0, 1.0, 1.0, 1.0, 1.0], + }), + weights="person_weight" + ) + + reform_dataset = PolicyEngineUKDataset( + name="Reform", + description="Reform dataset", + filepath=reform_filepath, + year=2024, + data=UKYearData(person=reform_person_df, benunit=benunit_df, household=household_df), + ) + + baseline_sim = Simulation( + dataset=baseline_dataset, + tax_benefit_model_version=uk_latest, + output_dataset=baseline_dataset, + ) + + reform_sim = Simulation( + dataset=reform_dataset, + tax_benefit_model_version=uk_latest, + output_dataset=reform_dataset, + ) + + # Count people age >= 30 who gain at least 2000 and at least 5% relative gain + agg = ChangeAggregate( + baseline_simulation=baseline_sim, + reform_simulation=reform_sim, + 
variable="employment_income", + aggregate_type=ChangeAggregateType.COUNT, + change_geq=2000, + relative_change_geq=0.05, + filter_variable="age", + filter_variable_geq=30, + ) + agg.run() + + # Should include: person 1 (10000/20%, age 30), person 3 (3000/5%, age 40), person 5 (4000/5%, age 45) + # Should exclude: person 2 (age 25), person 4 (only 800 gain) + assert agg.result == 3 From 5604fb688cb6d24a88eb5b2ed24a1c0501af6b10 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Thu, 13 Nov 2025 14:12:02 +0000 Subject: [PATCH 20/35] Parametric reform handling, plus UK datasets --- .gitignore | 1 + src/policyengine/tax_benefit_models/uk.py | 47 +++++++++++++++++--- src/policyengine/utils/parametric_reforms.py | 26 +++++++++++ 3 files changed, 69 insertions(+), 5 deletions(-) create mode 100644 src/policyengine/utils/parametric_reforms.py diff --git a/.gitignore b/.gitignore index 69dfad02..57a0fc21 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ *.ipynb _build/ .env +**/.DS_Store \ No newline at end of file diff --git a/src/policyengine/tax_benefit_models/uk.py b/src/policyengine/tax_benefit_models/uk.py index 1a935b5d..7ec9fa00 100644 --- a/src/policyengine/tax_benefit_models/uk.py +++ b/src/policyengine/tax_benefit_models/uk.py @@ -150,7 +150,9 @@ def __init__(self, **kwargs: dict): def save(self) -> None: """Save dataset to HDF5 file.""" - filepath = self.filepath + filepath = Path(self.filepath) + if not filepath.parent.exists(): + filepath.parent.mkdir(parents=True, exist_ok=True) with pd.HDFStore(filepath, mode="w") as store: store["person"] = pd.DataFrame(self.data.person) store["benunit"] = pd.DataFrame(self.data.benunit) @@ -273,6 +275,7 @@ def __init__(self, **kwargs: dict): def run(self, simulation: "Simulation") -> "Simulation": from policyengine_uk import Microsimulation from policyengine_uk.data import UKSingleYearDataset + from policyengine.utils.parametric_reforms import simulation_modifier_from_parameter_values assert isinstance(simulation.dataset, PolicyEngineUKDataset) @@ -291,12 +294,22 @@ def run(self, simulation: "Simulation") -> "Simulation": and simulation.policy.simulation_modifier is not None ): simulation.policy.simulation_modifier(microsim) + elif simulation.policy: + modifier = simulation_modifier_from_parameter_values( + simulation.policy.parameter_values + ) + modifier(microsim) if ( simulation.dynamic and simulation.dynamic.simulation_modifier is not None ): simulation.dynamic.simulation_modifier(microsim) + elif simulation.dynamic: + modifier = simulation_modifier_from_parameter_values( + simulation.dynamic.parameter_values + ) + modifier(microsim) # Allow custom variable selection, or use defaults if simulation.variables is not None: @@ -314,6 +327,7 @@ def run(self, simulation: "Simulation") -> "Simulation": "age", "gender", "is_adult", + "is_SP_age", "is_child", # Income "employment_income", @@ -342,11 +356,7 @@ def run(self, simulation: "Simulation") -> "Simulation": "benunit_weight", # Structure "family_type", - "num_adults", - "num_children", # Income and benefits - "benunit_total_income", - "benunit_net_income", "universal_credit", "child_benefit", "working_tax_credit", @@ -365,9 +375,11 @@ def run(self, simulation: "Simulation") -> "Simulation": # Benefits and tax "household_benefits", "household_tax", + "vat", # Housing "rent", "council_tax", + "tenure_type", ], } @@ -411,6 +423,31 @@ def run(self, simulation: "Simulation") -> "Simulation": simulation.output_dataset.save() +def create_datasets( + datasets: list[str] = [ + 
"hf://policyengine/policyengine-uk-data/frs_2023_24.h5", + "hf://policyengine/policyengine-uk-data/enhanced_frs_2023_24.h5", + ], + years: list[int] = [2026, 2027, 2028, 2029, 2030], +) -> None: + for dataset in datasets: + from policyengine_uk import Microsimulation + sim = Microsimulation(dataset=dataset) + for year in years: + year_dataset = sim.dataset[year] + uk_dataset = PolicyEngineUKDataset( + name=f"{dataset}-year-{year}", + description=f"UK Dataset for year {year} based on {dataset}", + filepath=f"./data/{Path(dataset).stem}_year_{year}.h5", + year=year, + data=UKYearData( + person=MicroDataFrame(year_dataset.person), + benunit=MicroDataFrame(year_dataset.benunit), + household=MicroDataFrame(year_dataset.household), + ), + ) + uk_dataset.save() + # Rebuild models to resolve forward references PolicyEngineUKDataset.model_rebuild() diff --git a/src/policyengine/utils/parametric_reforms.py b/src/policyengine/utils/parametric_reforms.py new file mode 100644 index 00000000..2c96b039 --- /dev/null +++ b/src/policyengine/utils/parametric_reforms.py @@ -0,0 +1,26 @@ +from policyengine.core import ParameterValue +from typing import Callable + + +def simulation_modifier_from_parameter_values(parameter_values: list[ParameterValue]) -> Callable: + """ + Create a simulation modifier function that applies the given parameter values to a simulation. + + Args: + parameter_values (list[ParameterValue]): List of ParameterValue objects to apply. + + Returns: + Callable: A function that takes a Simulation object and applies the parameter values. + """ + + def modifier(simulation): + for pv in parameter_values: + p = simulation.tax_benefit_system.parameters.get_child(pv.parameter.name) + p.update( + value=pv.value, + start=pv.start_date.strftime("%Y-%m-%d"), + stop=pv.stop_date.strftime("%Y-%m-%d") if pv.stop_date else None, + ) + return simulation + + return modifier \ No newline at end of file From 074283178d1561010c11a5ab396cf6ce8650fb3e Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Thu, 13 Nov 2025 15:54:32 +0000 Subject: [PATCH 21/35] Add more analysis functionality --- docs/index.md | 1 - examples/income_bands.py | 297 +++++----- examples/policy_change.py | 541 +++++++++--------- src/policyengine/outputs/__init__.py | 5 + src/policyengine/outputs/aggregate.py | 22 +- src/policyengine/outputs/base.py | 12 + src/policyengine/outputs/change_aggregate.py | 46 +- src/policyengine/outputs/decile_impact.py | 122 ++++ src/policyengine/tax_benefit_models/uk.py | 467 +-------------- .../tax_benefit_models/uk/__init__.py | 18 + .../tax_benefit_models/uk/analysis.py | 79 +++ .../tax_benefit_models/uk/datasets.py | 238 ++++++++ .../tax_benefit_models/uk/model.py | 255 +++++++++ .../tax_benefit_models/uk/outputs.py | 103 ++++ src/policyengine/utils/parametric_reforms.py | 7 +- 15 files changed, 1336 insertions(+), 877 deletions(-) create mode 100644 src/policyengine/outputs/base.py create mode 100644 src/policyengine/outputs/decile_impact.py create mode 100644 src/policyengine/tax_benefit_models/uk/__init__.py create mode 100644 src/policyengine/tax_benefit_models/uk/analysis.py create mode 100644 src/policyengine/tax_benefit_models/uk/datasets.py create mode 100644 src/policyengine/tax_benefit_models/uk/model.py create mode 100644 src/policyengine/tax_benefit_models/uk/outputs.py diff --git a/docs/index.md b/docs/index.md index 6e92c81c..dd467d12 100644 --- a/docs/index.md +++ b/docs/index.md @@ -4,7 +4,6 @@ This package aims to simplify and productionise the use of PolicyEngine's tax-be We do this 
by: * Standardising around a set of core types that let us do policy analysis in an object-oriented way -* Provide a nice clean interface to put instances of these types in a database * Exemplifying this behaviour by using this package in all PolicyEngine's production applications, and analyses In this documentation, we'll walk through the core concepts/types that this package makes available, and how you can use them to run policy analyses at scale. diff --git a/examples/income_bands.py b/examples/income_bands.py index bd665d1f..cd5819bd 100644 --- a/examples/income_bands.py +++ b/examples/income_bands.py @@ -1,167 +1,186 @@ -"""Example: Calculate total employment income by income band. +"""Example: Calculate net income and tax by income decile using representative microdata. This script demonstrates: -1. Creating a dataset with randomly sampled incomes (exponential distribution) -2. Using Aggregate to calculate statistics within income bands -3. Filtering with geq/leq constraints +1. Using representative household microdata +2. Running a full microsimulation to calculate income tax and net income +3. Using Aggregate to calculate statistics within income deciles using quantile filters 4. Visualising results with Plotly Run: python examples/income_bands.py """ -import numpy as np -import pandas as pd -from microdf import MicroDataFrame +from pathlib import Path import plotly.graph_objects as go from plotly.subplots import make_subplots from policyengine.core import Simulation from policyengine.tax_benefit_models.uk import ( PolicyEngineUKDataset, - UKYearData, uk_latest, ) from policyengine.outputs.aggregate import Aggregate, AggregateType -# Create sample data with random incomes (simplified - no simulation needed) -np.random.seed(42) -n_people = 1000 - -person_df = MicroDataFrame( - pd.DataFrame( - { - "person_id": range(1, n_people + 1), - "benunit_id": range(1, n_people + 1), - "household_id": range(1, n_people + 1), - "age": np.random.randint(18, 70, n_people), - "employment_income": np.random.exponential(35000, n_people), - "person_weight": np.ones(n_people), - } - ), - weights="person_weight", -) -benunit_df = MicroDataFrame( - pd.DataFrame( - { - "benunit_id": range(1, n_people + 1), - "benunit_weight": np.ones(n_people), - } - ), - weights="benunit_weight", -) +def load_representative_data(year: int = 2026) -> PolicyEngineUKDataset: + """Load representative household microdata for a given year.""" + dataset_path = Path(f"./data/enhanced_frs_2023_24_year_{year}.h5") -household_df = MicroDataFrame( - pd.DataFrame( - { - "household_id": range(1, n_people + 1), - "household_weight": np.ones(n_people), - } - ), - weights="household_weight", -) + if not dataset_path.exists(): + raise FileNotFoundError( + f"Dataset not found at {dataset_path}. " + "Run create_datasets() from policyengine.tax_benefit_models.uk first." 
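+            # This path matches the filenames that create_datasets() writes,
+            # i.e. ./data/<source dataset stem>_year_<year>.h5.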
+ ) -# Create dataset (this serves as our output dataset) -dataset = PolicyEngineUKDataset( - name="Sample Dataset", - description="Random sample for testing", - filepath="./sample_data.h5", - year=2024, - data=UKYearData( - person=person_df, benunit=benunit_df, household=household_df - ), -) + dataset = PolicyEngineUKDataset( + name=f"Enhanced FRS {year}", + description=f"Representative household microdata for {year}", + filepath=str(dataset_path), + year=year, + ) + dataset.load() + return dataset -# Create simulation with dataset as output -simulation = Simulation( - dataset=dataset, - tax_benefit_model_version=uk_latest, - output_dataset=dataset, -) -# Calculate total income by 10k bands -bands = [] -totals = [] -counts = [] - -for lower in range(0, 100000, 10000): - upper = lower + 10000 - - agg = Aggregate( - simulation=simulation, - variable="employment_income", - aggregate_type=AggregateType.SUM, - filter_variable="employment_income", - filter_variable_geq=lower, - filter_variable_leq=upper, +def run_simulation(dataset: PolicyEngineUKDataset) -> Simulation: + """Run a microsimulation on the dataset.""" + simulation = Simulation( + dataset=dataset, + tax_benefit_model_version=uk_latest, ) - agg.run() - - count_agg = Aggregate( - simulation=simulation, - variable="employment_income", - aggregate_type=AggregateType.COUNT, - filter_variable="employment_income", - filter_variable_geq=lower, - filter_variable_leq=upper, + simulation.run() + return simulation + + +def calculate_income_decile_statistics(simulation: Simulation) -> dict: + """Calculate total income, tax, and population by income deciles.""" + deciles = [] + net_incomes = [] + taxes = [] + counts = [] + + for decile in range(1, 11): + net_income_agg = Aggregate( + simulation=simulation, + variable="household_net_income", + aggregate_type=AggregateType.SUM, + filter_variable="household_net_income", + quantile=10, + quantile_eq=decile, + ) + net_income_agg.run() + + tax_agg = Aggregate( + simulation=simulation, + variable="household_tax", + aggregate_type=AggregateType.SUM, + filter_variable="household_net_income", + quantile=10, + quantile_eq=decile, + ) + tax_agg.run() + + count_agg = Aggregate( + simulation=simulation, + variable="household_net_income", + aggregate_type=AggregateType.COUNT, + filter_variable="household_net_income", + quantile=10, + quantile_eq=decile, + ) + count_agg.run() + + deciles.append(f"Decile {decile}") + net_incomes.append(net_income_agg.result / 1e9) # Convert to billions + taxes.append(tax_agg.result / 1e9) + counts.append(count_agg.result / 1e6) # Convert to millions + + return { + "deciles": deciles, + "net_incomes": net_incomes, + "taxes": taxes, + "counts": counts, + } + + +def visualise_results(results: dict) -> None: + """Create visualisations of income decile statistics.""" + fig = make_subplots( + rows=1, + cols=3, + subplot_titles=( + "Net income by decile (£bn)", + "Tax by decile (£bn)", + "Households by decile (millions)", + ), + specs=[[{"type": "bar"}, {"type": "bar"}, {"type": "bar"}]], ) - count_agg.run() - - bands.append(f"£{lower // 1000}k-£{upper // 1000}k") - totals.append(agg.result) - counts.append(count_agg.result) - -# Calculate 100k+ band -agg = Aggregate( - simulation=simulation, - variable="employment_income", - aggregate_type=AggregateType.SUM, - filter_variable="employment_income", - filter_variable_geq=100000, -) -agg.run() - -count_agg = Aggregate( - simulation=simulation, - variable="employment_income", - aggregate_type=AggregateType.COUNT, - 
filter_variable="employment_income", - filter_variable_geq=100000, -) -count_agg.run() - -bands.append("£100k+") -totals.append(agg.result) -counts.append(count_agg.result) - -# Create chart -fig = make_subplots( - rows=1, - cols=2, - subplot_titles=("Total income by band", "Population by band"), - specs=[[{"type": "bar"}, {"type": "bar"}]], -) -fig.add_trace( - go.Bar(x=bands, y=totals, name="Total income", marker_color="lightblue"), - row=1, - col=1, -) + fig.add_trace( + go.Bar( + x=results["deciles"], + y=results["net_incomes"], + marker_color="lightblue", + ), + row=1, + col=1, + ) -fig.add_trace( - go.Bar(x=bands, y=counts, name="Count", marker_color="lightgreen"), - row=1, - col=2, -) + fig.add_trace( + go.Bar( + x=results["deciles"], + y=results["taxes"], + marker_color="lightcoral", + ), + row=1, + col=2, + ) -fig.update_xaxes(title_text="Income band", row=1, col=1) -fig.update_xaxes(title_text="Income band", row=1, col=2) -fig.update_yaxes(title_text="Total income (£)", row=1, col=1) -fig.update_yaxes(title_text="Number of people", row=1, col=2) + fig.add_trace( + go.Bar( + x=results["deciles"], + y=results["counts"], + marker_color="lightgreen", + ), + row=1, + col=3, + ) + + fig.update_xaxes(title_text="Income decile", row=1, col=1) + fig.update_xaxes(title_text="Income decile", row=1, col=2) + fig.update_xaxes(title_text="Income decile", row=1, col=3) + + fig.update_layout( + title_text="Household income and tax distribution", + showlegend=False, + height=400, + ) + + fig.show() + + +def main(): + """Main execution function.""" + print("Loading representative household data...") + dataset = load_representative_data(year=2026) + + print("Running microsimulation...") + simulation = run_simulation(dataset) + + print("Calculating statistics by income decile...") + results = calculate_income_decile_statistics(simulation) + + print("\nResults summary:") + total_net_income = sum(results["net_incomes"]) + total_tax = sum(results["taxes"]) + total_households = sum(results["counts"]) + + print(f"Total net income: £{total_net_income:.1f}bn") + print(f"Total tax: £{total_tax:.1f}bn") + print(f"Total households: {total_households:.1f}m") + print(f"Average effective tax rate: {total_tax / (total_net_income + total_tax) * 100:.1f}%") + + print("\nGenerating visualisations...") + visualise_results(results) -fig.update_layout( - title_text="Employment income distribution", - showlegend=False, - height=400, -) -fig.show() +if __name__ == "__main__": + main() diff --git a/examples/policy_change.py b/examples/policy_change.py index 527c0bbe..243515f7 100644 --- a/examples/policy_change.py +++ b/examples/policy_change.py @@ -1,313 +1,308 @@ -"""Example: Analyze policy change impacts using ChangeAggregate. +"""Example: Analyse policy change impacts using ChangeAggregate with parametric reforms. This script demonstrates: -1. Creating baseline and reform datasets with different income distributions -2. Using ChangeAggregate to analyze winners, losers, and impact sizes -3. Filtering by absolute and relative change thresholds -4. Visualising results with Plotly +1. Loading representative household microdata +2. Applying parametric reforms (e.g., setting personal allowance to zero) +3. Running baseline and reform simulations +4. Using ChangeAggregate to analyse winners, losers, and impact sizes by income decile +5. Using quantile filters for decile-based analysis +6. 
Visualising results with Plotly Run: python examples/policy_change.py """ -import numpy as np -import pandas as pd -from microdf import MicroDataFrame + +from pathlib import Path +import datetime import plotly.graph_objects as go from plotly.subplots import make_subplots -from policyengine.core import Simulation +from policyengine.core import Simulation, Policy, Parameter, ParameterValue from policyengine.tax_benefit_models.uk import ( PolicyEngineUKDataset, - UKYearData, uk_latest, ) -from policyengine.outputs.change_aggregate import ChangeAggregate, ChangeAggregateType - -# Create baseline dataset with random incomes -np.random.seed(42) -n_people = 1000 - -baseline_person_df = MicroDataFrame( - pd.DataFrame({ - "person_id": range(1, n_people + 1), - "benunit_id": range(1, n_people + 1), - "household_id": range(1, n_people + 1), - "age": np.random.randint(18, 70, n_people), - "employment_income": np.random.exponential(35000, n_people), - "person_weight": np.ones(n_people), - }), - weights="person_weight" +from policyengine.outputs.change_aggregate import ( + ChangeAggregate, + ChangeAggregateType, ) -benunit_df = MicroDataFrame( - pd.DataFrame({ - "benunit_id": range(1, n_people + 1), - "benunit_weight": np.ones(n_people), - }), - weights="benunit_weight" -) -household_df = MicroDataFrame( - pd.DataFrame({ - "household_id": range(1, n_people + 1), - "household_weight": np.ones(n_people), - }), - weights="household_weight" -) +def load_representative_data(year: int = 2026) -> PolicyEngineUKDataset: + """Load representative household microdata for a given year.""" + dataset_path = Path(f"./data/enhanced_frs_2023_24_year_{year}.h5") -# Create baseline dataset -baseline_dataset = PolicyEngineUKDataset( - name="Baseline", - description="Baseline scenario", - filepath="./baseline_data.h5", - year=2024, - data=UKYearData(person=baseline_person_df, benunit=benunit_df, household=household_df), -) + if not dataset_path.exists(): + raise FileNotFoundError( + f"Dataset not found at {dataset_path}. " + "Run create_datasets() from policyengine.tax_benefit_models.uk first." 
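+            # create_datasets() defaults to years 2026-2030, so the 2026 file
+            # used below should exist once that helper has been run.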
+ ) -# Create reform dataset - progressive income boost -# Low earners get 10% boost, high earners get 5% boost, middle gets 7.5% -baseline_incomes = baseline_person_df["employment_income"].values -reform_incomes = [] -for income in baseline_incomes: - if income < 25000: - boost = 0.10 # 10% for low earners - elif income < 50000: - boost = 0.075 # 7.5% for middle earners - else: - boost = 0.05 # 5% for high earners - reform_incomes.append(income * (1 + boost)) - -reform_person_df = MicroDataFrame( - pd.DataFrame({ - "person_id": range(1, n_people + 1), - "benunit_id": range(1, n_people + 1), - "household_id": range(1, n_people + 1), - "age": baseline_person_df["age"].values, - "employment_income": reform_incomes, - "person_weight": np.ones(n_people), - }), - weights="person_weight" -) + dataset = PolicyEngineUKDataset( + name=f"Enhanced FRS {year}", + description=f"Representative household microdata for {year}", + filepath=str(dataset_path), + year=year, + ) + dataset.load() + return dataset + + +def create_personal_allowance_reform(year: int) -> Policy: + """Create a policy that sets personal allowance to zero.""" + parameter = Parameter( + id=f"{uk_latest.id}-gov.hmrc.income_tax.allowances.personal_allowance.amount", + name="gov.hmrc.income_tax.allowances.personal_allowance.amount", + tax_benefit_model_version=uk_latest, + description="Personal allowance for income tax", + data_type=float, + ) -reform_dataset = PolicyEngineUKDataset( - name="Reform", - description="Progressive income boost", - filepath="./reform_data.h5", - year=2024, - data=UKYearData(person=reform_person_df, benunit=benunit_df, household=household_df), -) + parameter_value = ParameterValue( + parameter=parameter, + start_date=datetime.date(year, 1, 1), + end_date=datetime.date(year, 12, 31), + value=0, + ) -# Create simulations -baseline_sim = Simulation( - dataset=baseline_dataset, - tax_benefit_model_version=uk_latest, - output_dataset=baseline_dataset, -) + return Policy( + name="Zero personal allowance", + description="Sets personal allowance to £0", + parameter_values=[parameter_value], + ) -reform_sim = Simulation( - dataset=reform_dataset, - tax_benefit_model_version=uk_latest, - output_dataset=reform_dataset, -) -# Analysis 1: Overall winners/losers -winners = ChangeAggregate( - baseline_simulation=baseline_sim, - reform_simulation=reform_sim, - variable="employment_income", - aggregate_type=ChangeAggregateType.COUNT, - change_geq=1, -) -winners.run() - -losers = ChangeAggregate( - baseline_simulation=baseline_sim, - reform_simulation=reform_sim, - variable="employment_income", - aggregate_type=ChangeAggregateType.COUNT, - change_leq=-1, -) -losers.run() - -no_change = ChangeAggregate( - baseline_simulation=baseline_sim, - reform_simulation=reform_sim, - variable="employment_income", - aggregate_type=ChangeAggregateType.COUNT, - change_eq=0, -) -no_change.run() - -# Analysis 2: Total gains and losses -total_gain = ChangeAggregate( - baseline_simulation=baseline_sim, - reform_simulation=reform_sim, - variable="employment_income", - aggregate_type=ChangeAggregateType.SUM, - change_geq=0, -) -total_gain.run() - -total_loss = ChangeAggregate( - baseline_simulation=baseline_sim, - reform_simulation=reform_sim, - variable="employment_income", - aggregate_type=ChangeAggregateType.SUM, - change_leq=0, -) -total_loss.run() - -# Analysis 3: Distribution of gains by size -gain_bands = [ - ("£0-500", 0, 500), - ("£500-1k", 500, 1000), - ("£1k-2k", 1000, 2000), - ("£2k-3k", 2000, 3000), - ("£3k-5k", 3000, 5000), - ("£5k+", 
5000, None), -] - -gain_counts = [] -for label, lower, upper in gain_bands: - agg = ChangeAggregate( +def run_baseline_simulation(dataset: PolicyEngineUKDataset) -> Simulation: + """Run baseline microsimulation without policy changes.""" + simulation = Simulation( + dataset=dataset, + tax_benefit_model_version=uk_latest, + ) + simulation.run() + return simulation + + +def run_reform_simulation( + dataset: PolicyEngineUKDataset, policy: Policy +) -> Simulation: + """Run reform microsimulation with policy changes.""" + simulation = Simulation( + dataset=dataset, + tax_benefit_model_version=uk_latest, + policy=policy, + ) + simulation.run() + return simulation + + +def analyse_overall_impact( + baseline_sim: Simulation, reform_sim: Simulation +) -> dict: + """Analyse overall winners, losers, and financial impact.""" + winners = ChangeAggregate( baseline_simulation=baseline_sim, reform_simulation=reform_sim, - variable="employment_income", + variable="household_net_income", aggregate_type=ChangeAggregateType.COUNT, - change_geq=lower, - change_leq=upper, + change_geq=1, ) - agg.run() - gain_counts.append(agg.result) - -# Analysis 4: Impact by age group -age_groups = [ - ("18-30", 18, 30), - ("31-45", 31, 45), - ("46-60", 46, 60), - ("61+", 61, 150), -] - -age_group_labels = [] -age_group_winners = [] -age_group_avg_gain = [] - -for label, min_age, max_age in age_groups: - # Count winners in age group - count_agg = ChangeAggregate( + winners.run() + + losers = ChangeAggregate( baseline_simulation=baseline_sim, reform_simulation=reform_sim, - variable="employment_income", + variable="household_net_income", aggregate_type=ChangeAggregateType.COUNT, - change_geq=1, - filter_variable="age", - filter_variable_geq=min_age, - filter_variable_leq=max_age, + change_leq=-1, ) - count_agg.run() + losers.run() - # Average gain in age group - mean_agg = ChangeAggregate( + no_change = ChangeAggregate( baseline_simulation=baseline_sim, reform_simulation=reform_sim, - variable="employment_income", - aggregate_type=ChangeAggregateType.MEAN, - filter_variable="age", - filter_variable_geq=min_age, - filter_variable_leq=max_age, + variable="household_net_income", + aggregate_type=ChangeAggregateType.COUNT, + change_eq=0, ) - mean_agg.run() - - age_group_labels.append(label) - age_group_winners.append(count_agg.result) - age_group_avg_gain.append(mean_agg.result) - -# Analysis 5: Large winners (gaining more than 10%) -large_winners = ChangeAggregate( - baseline_simulation=baseline_sim, - reform_simulation=reform_sim, - variable="employment_income", - aggregate_type=ChangeAggregateType.COUNT, - relative_change_geq=0.10, -) -large_winners.run() - -# Create visualisations -fig = make_subplots( - rows=2, cols=2, - subplot_titles=( - "Winners vs losers", - "Distribution of gains", - "Winners by age group", - "Average gain by age group" - ), - specs=[[{"type": "bar"}, {"type": "bar"}], - [{"type": "bar"}, {"type": "bar"}]] -) + no_change.run() -fig.add_trace( - go.Bar( - x=["Winners", "No change", "Losers"], - y=[winners.result, no_change.result, losers.result], - marker_color=["green", "gray", "red"] - ), - row=1, col=1 -) + total_change = ChangeAggregate( + baseline_simulation=baseline_sim, + reform_simulation=reform_sim, + variable="household_net_income", + aggregate_type=ChangeAggregateType.SUM, + ) + total_change.run() -fig.add_trace( - go.Bar( - x=[label for label, _, _ in gain_bands], - y=gain_counts, - marker_color="lightblue" - ), - row=1, col=2 -) + tax_revenue_change = ChangeAggregate( + 
baseline_simulation=baseline_sim, + reform_simulation=reform_sim, + variable="household_tax", + aggregate_type=ChangeAggregateType.SUM, + ) + tax_revenue_change.run() + + return { + "winners": winners.result / 1e6, # Convert to millions + "losers": losers.result / 1e6, + "no_change": no_change.result / 1e6, + "total_change": total_change.result / 1e9, # Convert to billions + "tax_revenue_change": tax_revenue_change.result / 1e9, + } + + +def analyse_impact_by_income_decile( + baseline_sim: Simulation, reform_sim: Simulation +) -> dict: + """Analyse impact by income decile.""" + decile_labels = [] + decile_losers = [] + decile_avg_loss = [] + + for decile in range(1, 11): + label = f"Decile {decile}" + + # Count losers in this decile + count_agg = ChangeAggregate( + baseline_simulation=baseline_sim, + reform_simulation=reform_sim, + variable="household_net_income", + aggregate_type=ChangeAggregateType.COUNT, + change_leq=-1, + filter_variable="household_net_income", + quantile=10, + quantile_eq=decile, + ) + count_agg.run() + + # Average loss for all households in this decile + mean_agg = ChangeAggregate( + baseline_simulation=baseline_sim, + reform_simulation=reform_sim, + variable="household_net_income", + aggregate_type=ChangeAggregateType.MEAN, + filter_variable="household_net_income", + quantile=10, + quantile_eq=decile, + ) + mean_agg.run() + + decile_labels.append(label) + decile_losers.append(count_agg.result / 1e6) # Convert to millions + decile_avg_loss.append(mean_agg.result) + + return { + "labels": decile_labels, + "losers": decile_losers, + "avg_loss": decile_avg_loss, + } + + +def visualise_results( + overall: dict, by_decile: dict, reform_name: str +) -> None: + """Create visualisations of policy change impacts.""" + fig = make_subplots( + rows=1, + cols=3, + subplot_titles=( + "Winners vs losers (millions)", + "Losers by income decile (millions)", + "Average loss by income decile (£)", + ), + specs=[[{"type": "bar"}, {"type": "bar"}, {"type": "bar"}]], + ) -fig.add_trace( - go.Bar( - x=age_group_labels, - y=age_group_winners, - marker_color="lightgreen" - ), - row=2, col=1 -) + fig.add_trace( + go.Bar( + x=["Winners", "No change", "Losers"], + y=[ + overall["winners"], + overall["no_change"], + overall["losers"], + ], + marker_color=["green", "gray", "red"], + ), + row=1, + col=1, + ) -fig.add_trace( - go.Bar( - x=age_group_labels, - y=age_group_avg_gain, - marker_color="orange" - ), - row=2, col=2 -) + fig.add_trace( + go.Bar( + x=by_decile["labels"], + y=by_decile["losers"], + marker_color="lightcoral", + ), + row=1, + col=2, + ) -fig.update_xaxes(title_text="Category", row=1, col=1) -fig.update_xaxes(title_text="Gain amount", row=1, col=2) -fig.update_xaxes(title_text="Age group", row=2, col=1) -fig.update_xaxes(title_text="Age group", row=2, col=2) + fig.add_trace( + go.Bar( + x=by_decile["labels"], + y=by_decile["avg_loss"], + marker_color="orange", + ), + row=1, + col=3, + ) -fig.update_yaxes(title_text="Number of people", row=1, col=1) -fig.update_yaxes(title_text="Number of people", row=1, col=2) -fig.update_yaxes(title_text="Number of winners", row=2, col=1) -fig.update_yaxes(title_text="Average gain (£)", row=2, col=2) + fig.update_xaxes(title_text="Category", row=1, col=1) + fig.update_xaxes(title_text="Income decile", row=1, col=2) + fig.update_xaxes(title_text="Income decile", row=1, col=3) + + fig.update_layout( + title_text=f"Policy change impact analysis: {reform_name}", + showlegend=False, + height=400, + ) + + fig.show() + + +def print_summary(overall: 
dict, decile: dict, reform_name: str) -> None: + """Print summary statistics.""" + print("=" * 60) + print(f"Policy change impact summary: {reform_name}") + print("=" * 60) + print(f"\nOverall impact:") + print(f" Winners: {overall['winners']:.2f}m households") + print(f" Losers: {overall['losers']:.2f}m households") + print(f" No change: {overall['no_change']:.2f}m households") + print(f"\nFinancial impact:") + print(f" Net income change: £{overall['total_change']:.2f}bn (negative = loss)") + print(f" Tax revenue change: £{overall['tax_revenue_change']:.2f}bn") + print(f"\nImpact by income decile:") + for i, label in enumerate(decile['labels']): + print(f" {label}: {decile['losers'][i]:.2f}m losers, avg change £{decile['avg_loss'][i]:.0f}") + print("=" * 60) + + +def main(): + """Main execution function.""" + year = 2026 + + print("Loading representative household data...") + dataset = load_representative_data(year=year) + + print("Creating policy reform (zero personal allowance)...") + reform = create_personal_allowance_reform(year=year) + + print("Running baseline simulation...") + baseline_sim = run_baseline_simulation(dataset) + + print("Running reform simulation...") + reform_sim = run_reform_simulation(dataset, reform) + + print("Analysing overall impact...") + overall_impact = analyse_overall_impact(baseline_sim, reform_sim) + + print("Analysing impact by income decile...") + decile_impact = analyse_impact_by_income_decile(baseline_sim, reform_sim) + + print_summary(overall_impact, decile_impact, reform.name) + + print("\nGenerating visualisations...") + visualise_results(overall_impact, decile_impact, reform.name) -fig.update_layout( - title_text="Policy change impact analysis", - showlegend=False, - height=800, -) -# Print summary statistics -print("=" * 60) -print("Policy change impact summary") -print("=" * 60) -print(f"\nOverall impact:") -print(f" Winners: {winners.result:,.0f} people") -print(f" Losers: {losers.result:,.0f} people") -print(f" No change: {no_change.result:,.0f} people") -print(f"\nFinancial impact:") -print(f" Total gains: £{total_gain.result:,.0f}") -print(f" Total losses: £{total_loss.result:,.0f}") -print(f" Net change: £{total_gain.result + total_loss.result:,.0f}") -print(f"\nLarge winners (>10% gain): {large_winners.result:,.0f} people") -print("=" * 60) - -fig.show() +if __name__ == "__main__": + main() diff --git a/src/policyengine/outputs/__init__.py b/src/policyengine/outputs/__init__.py index 8b255697..fc35bb27 100644 --- a/src/policyengine/outputs/__init__.py +++ b/src/policyengine/outputs/__init__.py @@ -1,9 +1,14 @@ +from policyengine.outputs.base import Output from policyengine.outputs.aggregate import Aggregate, AggregateType from policyengine.outputs.change_aggregate import ChangeAggregate, ChangeAggregateType +from policyengine.outputs.decile_impact import DecileImpact, calculate_decile_impacts __all__ = [ + "Output", "Aggregate", "AggregateType", "ChangeAggregate", "ChangeAggregateType", + "DecileImpact", + "calculate_decile_impacts", ] diff --git a/src/policyengine/outputs/aggregate.py b/src/policyengine/outputs/aggregate.py index 5c0c08c1..42e408c2 100644 --- a/src/policyengine/outputs/aggregate.py +++ b/src/policyengine/outputs/aggregate.py @@ -1,5 +1,5 @@ -from pydantic import BaseModel, Field from policyengine.core import * +from policyengine.outputs.base import Output from enum import Enum from typing import Any @@ -10,7 +10,7 @@ class AggregateType(str, Enum): COUNT = "count" -class Aggregate(BaseModel): +class Aggregate(Output): 
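+    """Weighted aggregate of one variable in a single simulation (for
+    example a sum or count), optionally restricted by the
+    filter_variable_* fields and the quantile helpers below.
+    """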
simulation: Simulation variable: str aggregate_type: AggregateType @@ -22,9 +22,27 @@ class Aggregate(BaseModel): filter_variable_geq: Any | None = None filter_variable_describes_quantiles: bool = False + # Convenient quantile specification (alternative to describes_quantiles) + quantile: int | None = None # Number of quantiles (e.g., 10 for deciles, 5 for quintiles) + quantile_eq: int | None = None # Exact quantile (e.g., 3 for 3rd decile) + quantile_leq: int | None = None # Maximum quantile (e.g., 5 for bottom 5 deciles) + quantile_geq: int | None = None # Minimum quantile (e.g., 9 for top 2 deciles) + result: Any | None = None def run(self): + # Convert quantile specification to describes_quantiles format + if self.quantile is not None: + self.filter_variable_describes_quantiles = True + if self.quantile_eq is not None: + # For a specific quantile, filter between (quantile-1)/n and quantile/n + self.filter_variable_geq = (self.quantile_eq - 1) / self.quantile + self.filter_variable_leq = self.quantile_eq / self.quantile + elif self.quantile_leq is not None: + self.filter_variable_leq = self.quantile_leq / self.quantile + elif self.quantile_geq is not None: + self.filter_variable_geq = (self.quantile_geq - 1) / self.quantile + # Get variable object var_obj = next( v diff --git a/src/policyengine/outputs/base.py b/src/policyengine/outputs/base.py new file mode 100644 index 00000000..46e2b46c --- /dev/null +++ b/src/policyengine/outputs/base.py @@ -0,0 +1,12 @@ +from pydantic import BaseModel + + +class Output(BaseModel): + """Base class for all output templates.""" + + def run(self): + """Calculate and populate the output fields. + + Must be implemented by subclasses. + """ + raise NotImplementedError("Subclasses must implement run()") diff --git a/src/policyengine/outputs/change_aggregate.py b/src/policyengine/outputs/change_aggregate.py index 57a01479..89975abf 100644 --- a/src/policyengine/outputs/change_aggregate.py +++ b/src/policyengine/outputs/change_aggregate.py @@ -1,5 +1,5 @@ -from pydantic import BaseModel from policyengine.core import * +from policyengine.outputs.base import Output from enum import Enum from typing import Any @@ -10,7 +10,7 @@ class ChangeAggregateType(str, Enum): MEAN = "mean" -class ChangeAggregate(BaseModel): +class ChangeAggregate(Output): baseline_simulation: Simulation reform_simulation: Simulation variable: str @@ -32,10 +32,29 @@ class ChangeAggregate(BaseModel): filter_variable_eq: Any | None = None filter_variable_leq: Any | None = None filter_variable_geq: Any | None = None + filter_variable_describes_quantiles: bool = False + + # Convenient quantile specification (alternative to describes_quantiles) + quantile: int | None = None # Number of quantiles (e.g., 10 for deciles, 5 for quintiles) + quantile_eq: int | None = None # Exact quantile (e.g., 3 for 3rd decile) + quantile_leq: int | None = None # Maximum quantile (e.g., 5 for bottom 5 deciles) + quantile_geq: int | None = None # Minimum quantile (e.g., 9 for top 2 deciles) result: Any | None = None def run(self): + # Convert quantile specification to describes_quantiles format + if self.quantile is not None: + self.filter_variable_describes_quantiles = True + if self.quantile_eq is not None: + # For a specific quantile, filter between (quantile-1)/n and quantile/n + self.filter_variable_geq = (self.quantile_eq - 1) / self.quantile + self.filter_variable_leq = self.quantile_eq / self.quantile + elif self.quantile_leq is not None: + self.filter_variable_leq = self.quantile_leq / self.quantile + 
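+                # e.g. quantile=10 with quantile_leq=5 keeps the bottom half,
+                # since 5 / 10 = 0.5 is the fractional cutoff that the filter
+                # step later resolves to a value via filter_series.quantile().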
elif self.quantile_geq is not None: + self.filter_variable_geq = (self.quantile_geq - 1) / self.quantile + # Get variable object var_obj = next( v for v in self.baseline_simulation.tax_benefit_model_version.variables @@ -106,12 +125,23 @@ def run(self): else: filter_series = baseline_data[self.filter_variable] - if self.filter_variable_eq is not None: - mask &= (filter_series == self.filter_variable_eq) - if self.filter_variable_leq is not None: - mask &= (filter_series <= self.filter_variable_leq) - if self.filter_variable_geq is not None: - mask &= (filter_series >= self.filter_variable_geq) + if self.filter_variable_describes_quantiles: + if self.filter_variable_eq is not None: + threshold = filter_series.quantile(self.filter_variable_eq) + mask &= (filter_series <= threshold) + if self.filter_variable_leq is not None: + threshold = filter_series.quantile(self.filter_variable_leq) + mask &= (filter_series <= threshold) + if self.filter_variable_geq is not None: + threshold = filter_series.quantile(self.filter_variable_geq) + mask &= (filter_series >= threshold) + else: + if self.filter_variable_eq is not None: + mask &= (filter_series == self.filter_variable_eq) + if self.filter_variable_leq is not None: + mask &= (filter_series <= self.filter_variable_leq) + if self.filter_variable_geq is not None: + mask &= (filter_series >= self.filter_variable_geq) # Apply mask to get filtered data filtered_change = change_series[mask] diff --git a/src/policyengine/outputs/decile_impact.py b/src/policyengine/outputs/decile_impact.py new file mode 100644 index 00000000..5249b2a8 --- /dev/null +++ b/src/policyengine/outputs/decile_impact.py @@ -0,0 +1,122 @@ +from policyengine.core import Simulation +from policyengine.outputs.base import Output +from pydantic import ConfigDict +import pandas as pd + + +class DecileImpact(Output): + """Single decile's impact from a policy reform - represents one database row.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + baseline_simulation: Simulation + reform_simulation: Simulation + income_variable: str = "equiv_hbai_household_net_income" + entity: str | None = None + decile: int + quantiles: int = 10 + + # Results populated by run() + baseline_mean: float | None = None + reform_mean: float | None = None + absolute_change: float | None = None + relative_change: float | None = None + count_better_off: float | None = None + count_worse_off: float | None = None + count_no_change: float | None = None + + def run(self): + """Calculate impact for this specific decile.""" + # Get variable object to determine entity + var_obj = next( + v + for v in self.baseline_simulation.tax_benefit_model_version.variables + if v.name == self.income_variable + ) + + # Get target entity + target_entity = self.entity or var_obj.entity + + # Get data from both simulations + baseline_data = getattr(self.baseline_simulation.output_dataset.data, target_entity) + reform_data = getattr(self.reform_simulation.output_dataset.data, target_entity) + + # Map income variable to target entity if needed + if var_obj.entity != target_entity: + baseline_mapped = self.baseline_simulation.output_dataset.data.map_to_entity( + var_obj.entity, target_entity + ) + baseline_income = baseline_mapped[self.income_variable] + + reform_mapped = self.reform_simulation.output_dataset.data.map_to_entity( + var_obj.entity, target_entity + ) + reform_income = reform_mapped[self.income_variable] + else: + baseline_income = baseline_data[self.income_variable] + reform_income = 
reform_data[self.income_variable] + + # Calculate deciles based on baseline income + decile_series = pd.qcut(baseline_income, self.quantiles, labels=False, duplicates='drop') + 1 + + # Calculate changes + absolute_change = reform_income - baseline_income + relative_change = (absolute_change / baseline_income) * 100 + + # Filter to this decile + mask = (decile_series == self.decile) + + # Populate results + self.baseline_mean = float(baseline_income[mask].mean()) + self.reform_mean = float(reform_income[mask].mean()) + self.absolute_change = float(absolute_change[mask].mean()) + self.relative_change = float(relative_change[mask].mean()) + self.count_better_off = float((absolute_change[mask] > 0).sum()) + self.count_worse_off = float((absolute_change[mask] < 0).sum()) + self.count_no_change = float((absolute_change[mask] == 0).sum()) + + +def calculate_decile_impacts( + baseline_simulation: Simulation, + reform_simulation: Simulation, + income_variable: str = "equiv_hbai_household_net_income", + entity: str | None = None, + quantiles: int = 10, +) -> tuple[list[DecileImpact], pd.DataFrame]: + """Calculate decile-by-decile impact of a reform. + + Returns: + tuple of (list of DecileImpact objects, DataFrame) + """ + results = [] + for decile in range(1, quantiles + 1): + impact = DecileImpact( + baseline_simulation=baseline_simulation, + reform_simulation=reform_simulation, + income_variable=income_variable, + entity=entity, + decile=decile, + quantiles=quantiles, + ) + impact.run() + results.append(impact) + + # Also create DataFrame for convenience + df = pd.DataFrame([ + { + "baseline_simulation_id": r.baseline_simulation.id, + "reform_simulation_id": r.reform_simulation.id, + "income_variable": r.income_variable, + "decile": r.decile, + "baseline_mean": r.baseline_mean, + "reform_mean": r.reform_mean, + "absolute_change": r.absolute_change, + "relative_change": r.relative_change, + "count_better_off": r.count_better_off, + "count_worse_off": r.count_worse_off, + "count_no_change": r.count_no_change, + } + for r in results + ]) + + return results, df diff --git a/src/policyengine/tax_benefit_models/uk.py b/src/policyengine/tax_benefit_models/uk.py index 7ec9fa00..a070033e 100644 --- a/src/policyengine/tax_benefit_models/uk.py +++ b/src/policyengine/tax_benefit_models/uk.py @@ -1,456 +1,19 @@ -from policyengine.core import * -from pydantic import BaseModel, Field, ConfigDict -import pandas as pd -from typing import Dict -import datetime -import requests -from policyengine.utils import parse_safe_date -from pathlib import Path -from importlib.metadata import version -from microdf import MicroDataFrame - - -class UKYearData(BaseModel): - """Entity-level data for a single year.""" - - model_config = ConfigDict(arbitrary_types_allowed=True) - - person: MicroDataFrame - benunit: MicroDataFrame - household: MicroDataFrame - - def map_to_entity( - self, source_entity: str, target_entity: str, columns: list[str] = None - ) -> MicroDataFrame: - """Map data from source entity to target entity using join keys. - - Args: - source_entity (str): The source entity name ('person', 'benunit', 'household'). - target_entity (str): The target entity name ('person', 'benunit', 'household'). - columns (list[str], optional): List of column names to map. If None, maps all columns. - - Returns: - MicroDataFrame: The mapped data at the target entity level. - - Raises: - ValueError: If source or target entity is invalid. 
- """ - valid_entities = {"person", "benunit", "household"} - if source_entity not in valid_entities: - raise ValueError( - f"Invalid source entity '{source_entity}'. Must be one of {valid_entities}" - ) - if target_entity not in valid_entities: - raise ValueError( - f"Invalid target entity '{target_entity}'. Must be one of {valid_entities}" - ) - - # Get source data - source_df = getattr(self, source_entity) - if columns: - # Select only requested columns (keep join keys) - join_keys = {"person_id", "benunit_id", "household_id"} - cols_to_keep = list( - set(columns) | (join_keys & set(source_df.columns)) - ) - source_df = source_df[cols_to_keep] - - # Determine weight column for target entity - weight_col_map = { - "person": "person_weight", - "benunit": "benunit_weight", - "household": "household_weight", - } - target_weight = weight_col_map[target_entity] - - # Same entity - return as is - if source_entity == target_entity: - return MicroDataFrame( - pd.DataFrame(source_df), weights=target_weight - ) - - # Map to different entity - target_df = getattr(self, target_entity) - - # Person -> Benunit - if source_entity == "person" and target_entity == "benunit": - result = pd.DataFrame(target_df).merge( - pd.DataFrame(source_df), on="benunit_id", how="left" - ) - return MicroDataFrame(result, weights=target_weight) - - # Person -> Household - elif source_entity == "person" and target_entity == "household": - result = pd.DataFrame(target_df).merge( - pd.DataFrame(source_df), on="household_id", how="left" - ) - return MicroDataFrame(result, weights=target_weight) - - # Benunit -> Person - elif source_entity == "benunit" and target_entity == "person": - result = pd.DataFrame(target_df).merge( - pd.DataFrame(source_df), on="benunit_id", how="left" - ) - return MicroDataFrame(result, weights=target_weight) - - # Benunit -> Household - elif source_entity == "benunit" and target_entity == "household": - # Need to go through person to link benunit and household - person_link = pd.DataFrame(self.person)[ - ["benunit_id", "household_id"] - ].drop_duplicates() - source_with_hh = pd.DataFrame(source_df).merge( - person_link, on="benunit_id", how="left" - ) - result = pd.DataFrame(target_df).merge( - source_with_hh, on="household_id", how="left" - ) - return MicroDataFrame(result, weights=target_weight) - - # Household -> Person - elif source_entity == "household" and target_entity == "person": - result = pd.DataFrame(target_df).merge( - pd.DataFrame(source_df), on="household_id", how="left" - ) - return MicroDataFrame(result, weights=target_weight) - - # Household -> Benunit - elif source_entity == "household" and target_entity == "benunit": - # Need to go through person to link household and benunit - person_link = pd.DataFrame(self.person)[ - ["benunit_id", "household_id"] - ].drop_duplicates() - source_with_bu = pd.DataFrame(source_df).merge( - person_link, on="household_id", how="left" - ) - result = pd.DataFrame(target_df).merge( - source_with_bu, on="benunit_id", how="left" - ) - return MicroDataFrame(result, weights=target_weight) - - else: - raise ValueError( - f"Unsupported mapping from {source_entity} to {target_entity}" - ) - - -class PolicyEngineUKDataset(Dataset): - """UK dataset with multi-year entity-level data.""" - - data: UKYearData | None = None - - def __init__(self, **kwargs: dict): - super().__init__(**kwargs) - - # Make sure we are synchronised between in-memory and storage, at least on initialisation. 
- if "data" in kwargs: - self.save() - elif "filepath" in kwargs: - self.load() - - def save(self) -> None: - """Save dataset to HDF5 file.""" - filepath = Path(self.filepath) - if not filepath.parent.exists(): - filepath.parent.mkdir(parents=True, exist_ok=True) - with pd.HDFStore(filepath, mode="w") as store: - store["person"] = pd.DataFrame(self.data.person) - store["benunit"] = pd.DataFrame(self.data.benunit) - store["household"] = pd.DataFrame(self.data.household) - - def load(self) -> None: - """Load dataset from HDF5 file into this instance.""" - filepath = self.filepath - with pd.HDFStore(filepath, mode="r") as store: - self.data = UKYearData( - person=MicroDataFrame( - store["person"], weights="person_weight" - ), - benunit=MicroDataFrame( - store["benunit"], weights="benunit_weight" - ), - household=MicroDataFrame( - store["household"], weights="household_weight" - ), - ) - - def __repr__(self) -> str: - if self.data is None: - return f"" - else: - n_people = len(self.data.person) - n_benunits = len(self.data.benunit) - n_households = len(self.data.household) - return f"" - - -class PolicyEngineUK(TaxBenefitModel): - id: str = "policyengine-uk" - description: str = "The UK's open-source dynamic tax and benefit microsimulation model maintained by PolicyEngine." - - -uk_model = PolicyEngineUK() - -pkg_version = version("policyengine-uk") - -# Get published time from PyPI -response = requests.get("https://pypi.org/pypi/policyengine-uk/json") -data = response.json() -upload_time = data["releases"][pkg_version][0]["upload_time_iso_8601"] - - -class PolicyEngineUKLatest(TaxBenefitModelVersion): - model: TaxBenefitModel = uk_model - version: str = pkg_version - created_at: datetime.datetime = datetime.datetime.fromisoformat( - upload_time - ) - - def __init__(self, **kwargs: dict): - super().__init__(**kwargs) - from policyengine_uk.system import system - from policyengine_core.enums import Enum - - self.id = f"{self.model.id}@{self.version}" - - self.variables = [] - for var_obj in system.variables.values(): - variable = Variable( - id=self.id + "-" + var_obj.name, - name=var_obj.name, - tax_benefit_model_version=self, - entity=var_obj.entity.key, - description=var_obj.documentation, - data_type=var_obj.value_type - if var_obj.value_type is not Enum - else str, - ) - if ( - hasattr(var_obj, "possible_values") - and var_obj.possible_values is not None - ): - variable.possible_values = list( - map( - lambda x: x.name, - var_obj.possible_values._value2member_map_.values(), - ) - ) - self.variables.append(variable) - - self.parameters = [] - from policyengine_core.parameters import Parameter as CoreParameter - - for param_node in system.parameters.get_descendants(): - if isinstance(param_node, CoreParameter): - parameter = Parameter( - id=self.id + "-" + param_node.name, - name=param_node.name, - tax_benefit_model_version=self, - description=param_node.description, - data_type=type( - param_node(2025) - ), # Example year to infer type - unit=param_node.metadata.get("unit"), - ) - self.parameters.append(parameter) - - for i in range(len(param_node.values_list)): - param_at_instant = param_node.values_list[i] - if i + 1 < len(param_node.values_list): - next_instant = param_node.values_list[i + 1] - else: - next_instant = None - parameter_value = ParameterValue( - parameter=parameter, - start_date=parse_safe_date( - param_at_instant.instant_str - ), - end_date=parse_safe_date(next_instant.instant_str) - if next_instant - else None, - value=param_at_instant.value, - ) - 
self.parameter_values.append(parameter_value) - - def run(self, simulation: "Simulation") -> "Simulation": - from policyengine_uk import Microsimulation - from policyengine_uk.data import UKSingleYearDataset - from policyengine.utils.parametric_reforms import simulation_modifier_from_parameter_values - - assert isinstance(simulation.dataset, PolicyEngineUKDataset) - - dataset = simulation.dataset - dataset.load() - input_data = UKSingleYearDataset( - person=dataset.data.person, - benunit=dataset.data.benunit, - household=dataset.data.household, - fiscal_year=dataset.year, - ) - microsim = Microsimulation(dataset=input_data) - - if ( - simulation.policy - and simulation.policy.simulation_modifier is not None - ): - simulation.policy.simulation_modifier(microsim) - elif simulation.policy: - modifier = simulation_modifier_from_parameter_values( - simulation.policy.parameter_values - ) - modifier(microsim) - - if ( - simulation.dynamic - and simulation.dynamic.simulation_modifier is not None - ): - simulation.dynamic.simulation_modifier(microsim) - elif simulation.dynamic: - modifier = simulation_modifier_from_parameter_values( - simulation.dynamic.parameter_values - ) - modifier(microsim) - - # Allow custom variable selection, or use defaults - if simulation.variables is not None: - entity_variables = simulation.variables - else: - # Default comprehensive variable set - entity_variables = { - "person": [ - # IDs and weights - "person_id", - "benunit_id", - "household_id", - "person_weight", - # Demographics - "age", - "gender", - "is_adult", - "is_SP_age", - "is_child", - # Income - "employment_income", - "self_employment_income", - "pension_income", - "private_pension_income", - "savings_interest_income", - "dividend_income", - "property_income", - "total_income", - "earned_income", - # Benefits - "universal_credit", - "child_benefit", - "pension_credit", - "income_support", - "working_tax_credit", - "child_tax_credit", - # Tax - "income_tax", - "national_insurance", - ], - "benunit": [ - # IDs and weights - "benunit_id", - "benunit_weight", - # Structure - "family_type", - # Income and benefits - "universal_credit", - "child_benefit", - "working_tax_credit", - "child_tax_credit", - ], - "household": [ - # IDs and weights - "household_id", - "household_weight", - # Income measures - "household_net_income", - "hbai_household_net_income", - "equiv_hbai_household_net_income", - "household_market_income", - "household_gross_income", - # Benefits and tax - "household_benefits", - "household_tax", - "vat", - # Housing - "rent", - "council_tax", - "tenure_type", - ], - } - - data = { - "person": pd.DataFrame(), - "benunit": pd.DataFrame(), - "household": pd.DataFrame(), - } - - for entity, variables in entity_variables.items(): - for var in variables: - data[entity][var] = microsim.calculate( - var, period=simulation.dataset.year, map_to=entity - ).values - - data["person"] = MicroDataFrame( - data["person"], weights="person_weight" - ) - data["benunit"] = MicroDataFrame( - data["benunit"], weights="benunit_weight" - ) - data["household"] = MicroDataFrame( - data["household"], weights="household_weight" - ) - - simulation.output_dataset = PolicyEngineUKDataset( - name=dataset.name, - description=dataset.description, - filepath=str( - Path(simulation.dataset.filepath).parent - / (simulation.id + ".h5") - ), - year=simulation.dataset.year, - is_output_dataset=True, - data=UKYearData( - person=data["person"], - benunit=data["benunit"], - household=data["household"], - ), - ) - - 
simulation.output_dataset.save() - -def create_datasets( - datasets: list[str] = [ - "hf://policyengine/policyengine-uk-data/frs_2023_24.h5", - "hf://policyengine/policyengine-uk-data/enhanced_frs_2023_24.h5", - ], - years: list[int] = [2026, 2027, 2028, 2029, 2030], -) -> None: - for dataset in datasets: - from policyengine_uk import Microsimulation - sim = Microsimulation(dataset=dataset) - for year in years: - year_dataset = sim.dataset[year] - uk_dataset = PolicyEngineUKDataset( - name=f"{dataset}-year-{year}", - description=f"UK Dataset for year {year} based on {dataset}", - filepath=f"./data/{Path(dataset).stem}_year_{year}.h5", - year=year, - data=UKYearData( - person=MicroDataFrame(year_dataset.person), - benunit=MicroDataFrame(year_dataset.benunit), - household=MicroDataFrame(year_dataset.household), - ), - ) - uk_dataset.save() - +"""PolicyEngine UK tax-benefit model - imports from uk/ module.""" + +from .uk import * + +__all__ = [ + "UKYearData", + "PolicyEngineUKDataset", + "create_datasets", + "PolicyEngineUK", + "PolicyEngineUKLatest", + "uk_model", + "uk_latest", + "general_policy_reform_analysis", + "ProgrammeStatistics", +] # Rebuild models to resolve forward references PolicyEngineUKDataset.model_rebuild() PolicyEngineUKLatest.model_rebuild() - -uk_latest = PolicyEngineUKLatest() diff --git a/src/policyengine/tax_benefit_models/uk/__init__.py b/src/policyengine/tax_benefit_models/uk/__init__.py new file mode 100644 index 00000000..91f58000 --- /dev/null +++ b/src/policyengine/tax_benefit_models/uk/__init__.py @@ -0,0 +1,18 @@ +"""PolicyEngine UK tax-benefit model.""" + +from .datasets import UKYearData, PolicyEngineUKDataset, create_datasets +from .model import PolicyEngineUK, PolicyEngineUKLatest, uk_model, uk_latest +from .analysis import general_policy_reform_analysis +from .outputs import ProgrammeStatistics + +__all__ = [ + "UKYearData", + "PolicyEngineUKDataset", + "create_datasets", + "PolicyEngineUK", + "PolicyEngineUKLatest", + "uk_model", + "uk_latest", + "general_policy_reform_analysis", + "ProgrammeStatistics", +] diff --git a/src/policyengine/tax_benefit_models/uk/analysis.py b/src/policyengine/tax_benefit_models/uk/analysis.py new file mode 100644 index 00000000..f29c3f47 --- /dev/null +++ b/src/policyengine/tax_benefit_models/uk/analysis.py @@ -0,0 +1,79 @@ +"""General utility functions for UK policy reform analysis.""" + +from policyengine.core import Simulation +from policyengine.outputs.decile_impact import DecileImpact, calculate_decile_impacts +from .outputs import ProgrammeStatistics +import pandas as pd + + +def general_policy_reform_analysis( + baseline_simulation: Simulation, + reform_simulation: Simulation, +) -> tuple[list[DecileImpact], list[ProgrammeStatistics], pd.DataFrame, pd.DataFrame]: + """Perform comprehensive analysis of a policy reform. 
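+
+    Usage sketch (illustrative names; assumes two completed simulations):
+
+        decile_impacts, programme_stats, decile_df, programme_df = (
+            general_policy_reform_analysis(baseline_sim, reform_sim)
+        )
+        programme_df[["programme_name", "change", "winners", "losers"]]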
+ + Returns: + tuple of: + - list[DecileImpact]: Decile-by-decile impacts + - list[ProgrammeStatistics]: Statistics for major programmes + - pd.DataFrame: Decile impacts as DataFrame + - pd.DataFrame: Programme statistics as DataFrame + """ + # Decile impact + decile_impacts, decile_df = calculate_decile_impacts( + baseline_simulation=baseline_simulation, + reform_simulation=reform_simulation, + ) + + # Major programmes to analyse + programmes = { + # Tax + "income_tax": {"entity": "person", "is_tax": True}, + "national_insurance": {"entity": "person", "is_tax": True}, + "vat": {"entity": "household", "is_tax": True}, + "council_tax": {"entity": "household", "is_tax": True}, + # Benefits + "universal_credit": {"entity": "person", "is_tax": False}, + "child_benefit": {"entity": "person", "is_tax": False}, + "pension_credit": {"entity": "person", "is_tax": False}, + "income_support": {"entity": "person", "is_tax": False}, + "working_tax_credit": {"entity": "person", "is_tax": False}, + "child_tax_credit": {"entity": "person", "is_tax": False}, + } + + programme_statistics = [] + + for programme_name, programme_info in programmes.items(): + entity = programme_info["entity"] + is_tax = programme_info["is_tax"] + + stats = ProgrammeStatistics( + baseline_simulation=baseline_simulation, + reform_simulation=reform_simulation, + programme_name=programme_name, + entity=entity, + is_tax=is_tax, + ) + stats.run() + programme_statistics.append(stats) + + # Create DataFrames for convenience + programme_df = pd.DataFrame([ + { + "baseline_simulation_id": p.baseline_simulation.id, + "reform_simulation_id": p.reform_simulation.id, + "programme_name": p.programme_name, + "entity": p.entity, + "is_tax": p.is_tax, + "baseline_total": p.baseline_total, + "reform_total": p.reform_total, + "change": p.change, + "baseline_count": p.baseline_count, + "reform_count": p.reform_count, + "winners": p.winners, + "losers": p.losers, + } + for p in programme_statistics + ]) + + return decile_impacts, programme_statistics, decile_df, programme_df diff --git a/src/policyengine/tax_benefit_models/uk/datasets.py b/src/policyengine/tax_benefit_models/uk/datasets.py new file mode 100644 index 00000000..073d3a95 --- /dev/null +++ b/src/policyengine/tax_benefit_models/uk/datasets.py @@ -0,0 +1,238 @@ +from policyengine.core import Dataset +from pydantic import BaseModel, ConfigDict +import pandas as pd +from microdf import MicroDataFrame +from pathlib import Path + + +class UKYearData(BaseModel): + """Entity-level data for a single year.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + person: MicroDataFrame + benunit: MicroDataFrame + household: MicroDataFrame + + def map_to_entity( + self, source_entity: str, target_entity: str, columns: list[str] = None + ) -> MicroDataFrame: + """Map data from source entity to target entity using join keys. + + Args: + source_entity (str): The source entity name ('person', 'benunit', 'household'). + target_entity (str): The target entity name ('person', 'benunit', 'household'). + columns (list[str], optional): List of column names to map. If None, maps all columns. + + Returns: + MicroDataFrame: The mapped data at the target entity level. + + Raises: + ValueError: If source or target entity is invalid. + """ + valid_entities = {"person", "benunit", "household"} + if source_entity not in valid_entities: + raise ValueError( + f"Invalid source entity '{source_entity}'. 
Must be one of {valid_entities}" + ) + if target_entity not in valid_entities: + raise ValueError( + f"Invalid target entity '{target_entity}'. Must be one of {valid_entities}" + ) + + # Get source data + source_df = getattr(self, source_entity) + if columns: + # Select only requested columns (keep join keys) + join_keys = {"person_id", "benunit_id", "household_id"} + cols_to_keep = list( + set(columns) | (join_keys & set(source_df.columns)) + ) + source_df = source_df[cols_to_keep] + + # Determine weight column for target entity + weight_col_map = { + "person": "person_weight", + "benunit": "benunit_weight", + "household": "household_weight", + } + target_weight = weight_col_map[target_entity] + + # Same entity - return as is + if source_entity == target_entity: + return MicroDataFrame( + pd.DataFrame(source_df), weights=target_weight + ) + + # Map to different entity + target_df = getattr(self, target_entity) + + # Person -> Benunit + if source_entity == "person" and target_entity == "benunit": + result = pd.DataFrame(target_df).merge( + pd.DataFrame(source_df), on="benunit_id", how="left" + ) + return MicroDataFrame(result, weights=target_weight) + + # Person -> Household + elif source_entity == "person" and target_entity == "household": + result = pd.DataFrame(target_df).merge( + pd.DataFrame(source_df), on="household_id", how="left" + ) + return MicroDataFrame(result, weights=target_weight) + + # Benunit -> Person + elif source_entity == "benunit" and target_entity == "person": + result = pd.DataFrame(target_df).merge( + pd.DataFrame(source_df), on="benunit_id", how="left" + ) + return MicroDataFrame(result, weights=target_weight) + + # Benunit -> Household + elif source_entity == "benunit" and target_entity == "household": + # Need to go through person to link benunit and household + person_link = pd.DataFrame(self.person)[ + ["benunit_id", "household_id"] + ].drop_duplicates() + source_with_hh = pd.DataFrame(source_df).merge( + person_link, on="benunit_id", how="left" + ) + result = pd.DataFrame(target_df).merge( + source_with_hh, on="household_id", how="left" + ) + return MicroDataFrame(result, weights=target_weight) + + # Household -> Person + elif source_entity == "household" and target_entity == "person": + result = pd.DataFrame(target_df).merge( + pd.DataFrame(source_df), on="household_id", how="left" + ) + return MicroDataFrame(result, weights=target_weight) + + # Household -> Benunit + elif source_entity == "household" and target_entity == "benunit": + # Need to go through person to link household and benunit + person_link = pd.DataFrame(self.person)[ + ["benunit_id", "household_id"] + ].drop_duplicates() + source_with_bu = pd.DataFrame(source_df).merge( + person_link, on="household_id", how="left" + ) + result = pd.DataFrame(target_df).merge( + source_with_bu, on="benunit_id", how="left" + ) + return MicroDataFrame(result, weights=target_weight) + + else: + raise ValueError( + f"Unsupported mapping from {source_entity} to {target_entity}" + ) + + +class PolicyEngineUKDataset(Dataset): + """UK dataset with multi-year entity-level data.""" + + data: UKYearData | None = None + + def __init__(self, **kwargs: dict): + super().__init__(**kwargs) + + # Make sure we are synchronised between in-memory and storage, at least on initialisation + if "data" in kwargs: + self.save() + elif "filepath" in kwargs: + self.load() + + def save(self) -> None: + """Save dataset to HDF5 file.""" + filepath = Path(self.filepath) + if not filepath.parent.exists(): + 
filepath.parent.mkdir(parents=True, exist_ok=True) + with pd.HDFStore(filepath, mode="w") as store: + store["person"] = pd.DataFrame(self.data.person) + store["benunit"] = pd.DataFrame(self.data.benunit) + store["household"] = pd.DataFrame(self.data.household) + + def load(self) -> None: + """Load dataset from HDF5 file into this instance.""" + filepath = self.filepath + with pd.HDFStore(filepath, mode="r") as store: + self.data = UKYearData( + person=MicroDataFrame( + store["person"], weights="person_weight" + ), + benunit=MicroDataFrame( + store["benunit"], weights="benunit_weight" + ), + household=MicroDataFrame( + store["household"], weights="household_weight" + ), + ) + + def __repr__(self) -> str: + if self.data is None: + return f"" + else: + n_people = len(self.data.person) + n_benunits = len(self.data.benunit) + n_households = len(self.data.household) + return f"" + + +def create_datasets( + datasets: list[str] = [ + "hf://policyengine/policyengine-uk-data/frs_2023_24.h5", + "hf://policyengine/policyengine-uk-data/enhanced_frs_2023_24.h5", + ], + years: list[int] = [2026, 2027, 2028, 2029, 2030], +) -> None: + for dataset in datasets: + from policyengine_uk import Microsimulation + sim = Microsimulation(dataset=dataset) + for year in years: + year_dataset = sim.dataset[year] + + # Convert to pandas DataFrames and add weight columns + person_df = pd.DataFrame(year_dataset.person) + benunit_df = pd.DataFrame(year_dataset.benunit) + household_df = pd.DataFrame(year_dataset.household) + + # Map household weights to person and benunit levels + person_df = person_df.merge( + household_df[["household_id", "household_weight"]], + left_on="person_household_id", + right_on="household_id", + how="left" + ) + person_df = person_df.rename(columns={"household_weight": "person_weight"}) + person_df = person_df.drop(columns=["household_id"]) + + # Get household_id for each benunit from person table + benunit_household_map = person_df[["person_benunit_id", "person_household_id"]].drop_duplicates() + benunit_df = benunit_df.merge( + benunit_household_map, + left_on="benunit_id", + right_on="person_benunit_id", + how="left" + ) + benunit_df = benunit_df.merge( + household_df[["household_id", "household_weight"]], + left_on="person_household_id", + right_on="household_id", + how="left" + ) + benunit_df = benunit_df.rename(columns={"household_weight": "benunit_weight"}) + benunit_df = benunit_df.drop(columns=["person_benunit_id", "person_household_id", "household_id"], errors="ignore") + + uk_dataset = PolicyEngineUKDataset( + name=f"{dataset}-year-{year}", + description=f"UK Dataset for year {year} based on {dataset}", + filepath=f"./data/{Path(dataset).stem}_year_{year}.h5", + year=year, + data=UKYearData( + person=MicroDataFrame(person_df, weights="person_weight"), + benunit=MicroDataFrame(benunit_df, weights="benunit_weight"), + household=MicroDataFrame(household_df, weights="household_weight"), + ), + ) + uk_dataset.save() diff --git a/src/policyengine/tax_benefit_models/uk/model.py b/src/policyengine/tax_benefit_models/uk/model.py new file mode 100644 index 00000000..d6c0bf72 --- /dev/null +++ b/src/policyengine/tax_benefit_models/uk/model.py @@ -0,0 +1,255 @@ +from policyengine.core import TaxBenefitModel, TaxBenefitModelVersion, Variable, Parameter, ParameterValue +import datetime +import requests +from importlib.metadata import version +from policyengine.utils import parse_safe_date +import pandas as pd +from microdf import MicroDataFrame +from pathlib import Path +from .datasets import 
PolicyEngineUKDataset, UKYearData +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from policyengine.core.simulation import Simulation + + +class PolicyEngineUK(TaxBenefitModel): + id: str = "policyengine-uk" + description: str = "The UK's open-source dynamic tax and benefit microsimulation model maintained by PolicyEngine." + + +uk_model = PolicyEngineUK() + +pkg_version = version("policyengine-uk") + +# Get published time from PyPI +response = requests.get("https://pypi.org/pypi/policyengine-uk/json") +data = response.json() +upload_time = data["releases"][pkg_version][0]["upload_time_iso_8601"] + + +class PolicyEngineUKLatest(TaxBenefitModelVersion): + model: TaxBenefitModel = uk_model + version: str = pkg_version + created_at: datetime.datetime = datetime.datetime.fromisoformat( + upload_time + ) + + def __init__(self, **kwargs: dict): + super().__init__(**kwargs) + from policyengine_uk.system import system + from policyengine_core.enums import Enum + + self.id = f"{self.model.id}@{self.version}" + + self.variables = [] + for var_obj in system.variables.values(): + variable = Variable( + id=self.id + "-" + var_obj.name, + name=var_obj.name, + tax_benefit_model_version=self, + entity=var_obj.entity.key, + description=var_obj.documentation, + data_type=var_obj.value_type + if var_obj.value_type is not Enum + else str, + ) + if ( + hasattr(var_obj, "possible_values") + and var_obj.possible_values is not None + ): + variable.possible_values = list( + map( + lambda x: x.name, + var_obj.possible_values._value2member_map_.values(), + ) + ) + self.variables.append(variable) + + self.parameters = [] + from policyengine_core.parameters import Parameter as CoreParameter + + for param_node in system.parameters.get_descendants(): + if isinstance(param_node, CoreParameter): + parameter = Parameter( + id=self.id + "-" + param_node.name, + name=param_node.name, + tax_benefit_model_version=self, + description=param_node.description, + data_type=type( + param_node(2025) + ), # Example year to infer type + unit=param_node.metadata.get("unit"), + ) + self.parameters.append(parameter) + + for i in range(len(param_node.values_list)): + param_at_instant = param_node.values_list[i] + if i + 1 < len(param_node.values_list): + next_instant = param_node.values_list[i + 1] + else: + next_instant = None + parameter_value = ParameterValue( + parameter=parameter, + start_date=parse_safe_date( + param_at_instant.instant_str + ), + end_date=parse_safe_date(next_instant.instant_str) + if next_instant + else None, + value=param_at_instant.value, + ) + self.parameter_values.append(parameter_value) + + def run(self, simulation: "Simulation") -> "Simulation": + from policyengine_uk import Microsimulation + from policyengine_uk.data import UKSingleYearDataset + from policyengine.utils.parametric_reforms import simulation_modifier_from_parameter_values + + assert isinstance(simulation.dataset, PolicyEngineUKDataset) + + dataset = simulation.dataset + dataset.load() + input_data = UKSingleYearDataset( + person=dataset.data.person, + benunit=dataset.data.benunit, + household=dataset.data.household, + fiscal_year=dataset.year, + ) + microsim = Microsimulation(dataset=input_data) + + if ( + simulation.policy + and simulation.policy.simulation_modifier is not None + ): + simulation.policy.simulation_modifier(microsim) + elif simulation.policy: + modifier = simulation_modifier_from_parameter_values( + simulation.policy.parameter_values + ) + modifier(microsim) + + if ( + simulation.dynamic + and 
simulation.dynamic.simulation_modifier is not None + ): + simulation.dynamic.simulation_modifier(microsim) + elif simulation.dynamic: + modifier = simulation_modifier_from_parameter_values( + simulation.dynamic.parameter_values + ) + modifier(microsim) + + # Allow custom variable selection, or use defaults + if simulation.variables is not None: + entity_variables = simulation.variables + else: + # Default comprehensive variable set + entity_variables = { + "person": [ + # IDs and weights + "person_id", + "benunit_id", + "household_id", + "person_weight", + # Demographics + "age", + "gender", + "is_adult", + "is_SP_age", + "is_child", + # Income + "employment_income", + "self_employment_income", + "pension_income", + "private_pension_income", + "savings_interest_income", + "dividend_income", + "property_income", + "total_income", + "earned_income", + # Benefits + "universal_credit", + "child_benefit", + "pension_credit", + "income_support", + "working_tax_credit", + "child_tax_credit", + # Tax + "income_tax", + "national_insurance", + ], + "benunit": [ + # IDs and weights + "benunit_id", + "benunit_weight", + # Structure + "family_type", + # Income and benefits + "universal_credit", + "child_benefit", + "working_tax_credit", + "child_tax_credit", + ], + "household": [ + # IDs and weights + "household_id", + "household_weight", + # Income measures + "household_net_income", + "hbai_household_net_income", + "equiv_hbai_household_net_income", + "household_market_income", + "household_gross_income", + # Benefits and tax + "household_benefits", + "household_tax", + "vat", + # Housing + "rent", + "council_tax", + "tenure_type", + ], + } + + data = { + "person": pd.DataFrame(), + "benunit": pd.DataFrame(), + "household": pd.DataFrame(), + } + + for entity, variables in entity_variables.items(): + for var in variables: + data[entity][var] = microsim.calculate( + var, period=simulation.dataset.year, map_to=entity + ).values + + data["person"] = MicroDataFrame( + data["person"], weights="person_weight" + ) + data["benunit"] = MicroDataFrame( + data["benunit"], weights="benunit_weight" + ) + data["household"] = MicroDataFrame( + data["household"], weights="household_weight" + ) + + simulation.output_dataset = PolicyEngineUKDataset( + name=dataset.name, + description=dataset.description, + filepath=str( + Path(simulation.dataset.filepath).parent + / (simulation.id + ".h5") + ), + year=simulation.dataset.year, + is_output_dataset=True, + data=UKYearData( + person=data["person"], + benunit=data["benunit"], + household=data["household"], + ), + ) + + simulation.output_dataset.save() + + +uk_latest = PolicyEngineUKLatest() diff --git a/src/policyengine/tax_benefit_models/uk/outputs.py b/src/policyengine/tax_benefit_models/uk/outputs.py new file mode 100644 index 00000000..ade22ee9 --- /dev/null +++ b/src/policyengine/tax_benefit_models/uk/outputs.py @@ -0,0 +1,103 @@ +"""UK-specific output templates.""" + +from policyengine.outputs.base import Output +from policyengine.outputs.aggregate import Aggregate, AggregateType +from policyengine.outputs.change_aggregate import ChangeAggregate, ChangeAggregateType +from pydantic import ConfigDict +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from policyengine.core.simulation import Simulation + + +class ProgrammeStatistics(Output): + """Single programme's statistics from a policy reform - represents one database row.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + baseline_simulation: "Simulation" + reform_simulation: "Simulation" + 
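+    # is_tax flips the winner/loser sign convention in run(): for a tax,
+    # a fall in the amount paid counts as a win.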
programme_name: str
+    entity: str
+    is_tax: bool = False
+
+    # Results populated by run()
+    baseline_total: float | None = None
+    reform_total: float | None = None
+    change: float | None = None
+    baseline_count: float | None = None
+    reform_count: float | None = None
+    winners: float | None = None
+    losers: float | None = None
+
+    def run(self):
+        """Calculate statistics for this programme."""
+        # Baseline totals
+        baseline_total = Aggregate(
+            simulation=self.baseline_simulation,
+            variable=self.programme_name,
+            aggregate_type=AggregateType.SUM,
+            entity=self.entity,
+        )
+        baseline_total.run()
+
+        # Reform totals
+        reform_total = Aggregate(
+            simulation=self.reform_simulation,
+            variable=self.programme_name,
+            aggregate_type=AggregateType.SUM,
+            entity=self.entity,
+        )
+        reform_total.run()
+
+        # Count of recipients/payers (baseline)
+        baseline_count = Aggregate(
+            simulation=self.baseline_simulation,
+            variable=self.programme_name,
+            aggregate_type=AggregateType.COUNT,
+            entity=self.entity,
+            filter_variable=self.programme_name,
+            filter_variable_geq=0.01,
+        )
+        baseline_count.run()
+
+        # Count of recipients/payers (reform)
+        reform_count = Aggregate(
+            simulation=self.reform_simulation,
+            variable=self.programme_name,
+            aggregate_type=AggregateType.COUNT,
+            entity=self.entity,
+            filter_variable=self.programme_name,
+            filter_variable_geq=0.01,
+        )
+        reform_count.run()
+
+        # Winners and losers: for a benefit, gaining at least 1p is a win;
+        # for a tax, paying at least 1p less is a win.
+        winners = ChangeAggregate(
+            baseline_simulation=self.baseline_simulation,
+            reform_simulation=self.reform_simulation,
+            variable=self.programme_name,
+            aggregate_type=ChangeAggregateType.COUNT,
+            entity=self.entity,
+            change_leq=-0.01 if self.is_tax else None,
+            change_geq=None if self.is_tax else 0.01,
+        )
+        winners.run()
+
+        losers = ChangeAggregate(
+            baseline_simulation=self.baseline_simulation,
+            reform_simulation=self.reform_simulation,
+            variable=self.programme_name,
+            aggregate_type=ChangeAggregateType.COUNT,
+            entity=self.entity,
+            change_geq=0.01 if self.is_tax else None,
+            change_leq=None if self.is_tax else -0.01,
+        )
+        losers.run()
+
+        # Populate results
+        self.baseline_total = float(baseline_total.result)
+        self.reform_total = float(reform_total.result)
+        self.change = float(reform_total.result - baseline_total.result)
+        self.baseline_count = float(baseline_count.result)
+        self.reform_count = float(reform_count.result)
+        self.winners = float(winners.result)
+        self.losers = float(losers.result)
diff --git a/src/policyengine/utils/parametric_reforms.py b/src/policyengine/utils/parametric_reforms.py
index 2c96b039..81c9fba3 100644
--- a/src/policyengine/utils/parametric_reforms.py
+++ b/src/policyengine/utils/parametric_reforms.py
@@ -1,5 +1,6 @@
 from policyengine.core import ParameterValue
 from typing import Callable
+from policyengine_core.periods import period
 
 
 def simulation_modifier_from_parameter_values(parameter_values: list[ParameterValue]) -> Callable:
@@ -16,10 +17,12 @@ def simulation_modifier_from_parameter_values(parameter_values: list[ParameterVa
     def modifier(simulation):
         for pv in parameter_values:
             p = simulation.tax_benefit_system.parameters.get_child(pv.parameter.name)
+            start_period = period(pv.start_date.strftime("%Y-%m-%d"))
+            stop_period = period(pv.end_date.strftime("%Y-%m-%d")) if pv.end_date else None
             p.update(
                 value=pv.value,
-                start=pv.start_date.strftime("%Y-%m-%d"),
-                stop=pv.stop_date.strftime("%Y-%m-%d") if pv.stop_date else None,
+                start=start_period,
+                stop=stop_period,
             )
 
     return simulation


From 055e8e0f347c6f51544ac1852d382530c9af54c1 Mon Sep 17 00:00:00 2001
From: Nikhil Woodruff
Date: Thu, 13 Nov 2025 16:09:04 +0000
Subject: [PATCH
22/35] Add US basics --- src/policyengine/core/__init__.py | 1 + src/policyengine/core/output.py | 25 +++++ src/policyengine/outputs/__init__.py | 3 +- src/policyengine/outputs/aggregate.py | 1 - src/policyengine/outputs/base.py | 12 -- src/policyengine/outputs/change_aggregate.py | 1 - src/policyengine/outputs/decile_impact.py | 11 +- .../tax_benefit_models/uk/analysis.py | 32 ++++-- .../tax_benefit_models/uk/outputs.py | 2 +- src/policyengine/tax_benefit_models/us.py | 20 ++++ .../tax_benefit_models/us/__init__.py | 19 ++++ .../tax_benefit_models/us/analysis.py | 94 ++++++++++++++++ .../tax_benefit_models/us/model.py | 105 ++++++++++++++++++ .../tax_benefit_models/us/outputs.py | 103 +++++++++++++++++ 14 files changed, 397 insertions(+), 32 deletions(-) create mode 100644 src/policyengine/core/output.py delete mode 100644 src/policyengine/outputs/base.py create mode 100644 src/policyengine/tax_benefit_models/us.py create mode 100644 src/policyengine/tax_benefit_models/us/__init__.py create mode 100644 src/policyengine/tax_benefit_models/us/analysis.py create mode 100644 src/policyengine/tax_benefit_models/us/model.py create mode 100644 src/policyengine/tax_benefit_models/us/outputs.py diff --git a/src/policyengine/core/__init__.py b/src/policyengine/core/__init__.py index ef86cb16..f94a0c3b 100644 --- a/src/policyengine/core/__init__.py +++ b/src/policyengine/core/__init__.py @@ -8,6 +8,7 @@ from .policy import Policy from .simulation import Simulation from .dataset_version import DatasetVersion +from .output import Output, OutputCollection # Rebuild models to resolve forward references TaxBenefitModelVersion.model_rebuild() diff --git a/src/policyengine/core/output.py b/src/policyengine/core/output.py new file mode 100644 index 00000000..26418d69 --- /dev/null +++ b/src/policyengine/core/output.py @@ -0,0 +1,25 @@ +from pydantic import BaseModel, ConfigDict +import pandas as pd +from typing import Generic, TypeVar + +T = TypeVar('T', bound='Output') + + +class Output(BaseModel): + """Base class for all output templates.""" + + def run(self): + """Calculate and populate the output fields. + + Must be implemented by subclasses. 
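+
+        Implementations are expected to mutate self in place, filling their
+        result fields, and to return None (as DecileImpact.run does).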
+ """ + raise NotImplementedError("Subclasses must implement run()") + + +class OutputCollection(BaseModel, Generic[T]): + """Container for a collection of outputs with their DataFrame representation.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + outputs: list[T] + dataframe: pd.DataFrame diff --git a/src/policyengine/outputs/__init__.py b/src/policyengine/outputs/__init__.py index fc35bb27..3a873415 100644 --- a/src/policyengine/outputs/__init__.py +++ b/src/policyengine/outputs/__init__.py @@ -1,10 +1,11 @@ -from policyengine.outputs.base import Output +from policyengine.core import Output, OutputCollection from policyengine.outputs.aggregate import Aggregate, AggregateType from policyengine.outputs.change_aggregate import ChangeAggregate, ChangeAggregateType from policyengine.outputs.decile_impact import DecileImpact, calculate_decile_impacts __all__ = [ "Output", + "OutputCollection", "Aggregate", "AggregateType", "ChangeAggregate", diff --git a/src/policyengine/outputs/aggregate.py b/src/policyengine/outputs/aggregate.py index 42e408c2..120b125a 100644 --- a/src/policyengine/outputs/aggregate.py +++ b/src/policyengine/outputs/aggregate.py @@ -1,5 +1,4 @@ from policyengine.core import * -from policyengine.outputs.base import Output from enum import Enum from typing import Any diff --git a/src/policyengine/outputs/base.py b/src/policyengine/outputs/base.py deleted file mode 100644 index 46e2b46c..00000000 --- a/src/policyengine/outputs/base.py +++ /dev/null @@ -1,12 +0,0 @@ -from pydantic import BaseModel - - -class Output(BaseModel): - """Base class for all output templates.""" - - def run(self): - """Calculate and populate the output fields. - - Must be implemented by subclasses. - """ - raise NotImplementedError("Subclasses must implement run()") diff --git a/src/policyengine/outputs/change_aggregate.py b/src/policyengine/outputs/change_aggregate.py index 89975abf..f3ab7e78 100644 --- a/src/policyengine/outputs/change_aggregate.py +++ b/src/policyengine/outputs/change_aggregate.py @@ -1,5 +1,4 @@ from policyengine.core import * -from policyengine.outputs.base import Output from enum import Enum from typing import Any diff --git a/src/policyengine/outputs/decile_impact.py b/src/policyengine/outputs/decile_impact.py index 5249b2a8..b1a16359 100644 --- a/src/policyengine/outputs/decile_impact.py +++ b/src/policyengine/outputs/decile_impact.py @@ -1,5 +1,4 @@ -from policyengine.core import Simulation -from policyengine.outputs.base import Output +from policyengine.core import Simulation, Output, OutputCollection from pydantic import ConfigDict import pandas as pd @@ -82,11 +81,11 @@ def calculate_decile_impacts( income_variable: str = "equiv_hbai_household_net_income", entity: str | None = None, quantiles: int = 10, -) -> tuple[list[DecileImpact], pd.DataFrame]: +) -> OutputCollection[DecileImpact]: """Calculate decile-by-decile impact of a reform. 
Returns: - tuple of (list of DecileImpact objects, DataFrame) + OutputCollection containing list of DecileImpact objects and DataFrame """ results = [] for decile in range(1, quantiles + 1): @@ -101,7 +100,7 @@ def calculate_decile_impacts( impact.run() results.append(impact) - # Also create DataFrame for convenience + # Create DataFrame df = pd.DataFrame([ { "baseline_simulation_id": r.baseline_simulation.id, @@ -119,4 +118,4 @@ def calculate_decile_impacts( for r in results ]) - return results, df + return OutputCollection(outputs=results, dataframe=df) diff --git a/src/policyengine/tax_benefit_models/uk/analysis.py b/src/policyengine/tax_benefit_models/uk/analysis.py index f29c3f47..18ab1f5f 100644 --- a/src/policyengine/tax_benefit_models/uk/analysis.py +++ b/src/policyengine/tax_benefit_models/uk/analysis.py @@ -1,26 +1,30 @@ """General utility functions for UK policy reform analysis.""" -from policyengine.core import Simulation +from policyengine.core import Simulation, OutputCollection from policyengine.outputs.decile_impact import DecileImpact, calculate_decile_impacts from .outputs import ProgrammeStatistics +from pydantic import BaseModel import pandas as pd +class PolicyReformAnalysis(BaseModel): + """Complete policy reform analysis result.""" + + decile_impacts: OutputCollection[DecileImpact] + programme_statistics: OutputCollection[ProgrammeStatistics] + + def general_policy_reform_analysis( baseline_simulation: Simulation, reform_simulation: Simulation, -) -> tuple[list[DecileImpact], list[ProgrammeStatistics], pd.DataFrame, pd.DataFrame]: +) -> PolicyReformAnalysis: """Perform comprehensive analysis of a policy reform. Returns: - tuple of: - - list[DecileImpact]: Decile-by-decile impacts - - list[ProgrammeStatistics]: Statistics for major programmes - - pd.DataFrame: Decile impacts as DataFrame - - pd.DataFrame: Programme statistics as DataFrame + PolicyReformAnalysis containing decile impacts and programme statistics """ # Decile impact - decile_impacts, decile_df = calculate_decile_impacts( + decile_impacts = calculate_decile_impacts( baseline_simulation=baseline_simulation, reform_simulation=reform_simulation, ) @@ -57,7 +61,7 @@ def general_policy_reform_analysis( stats.run() programme_statistics.append(stats) - # Create DataFrames for convenience + # Create DataFrame programme_df = pd.DataFrame([ { "baseline_simulation_id": p.baseline_simulation.id, @@ -76,4 +80,12 @@ def general_policy_reform_analysis( for p in programme_statistics ]) - return decile_impacts, programme_statistics, decile_df, programme_df + programme_collection = OutputCollection( + outputs=programme_statistics, + dataframe=programme_df + ) + + return PolicyReformAnalysis( + decile_impacts=decile_impacts, + programme_statistics=programme_collection + ) diff --git a/src/policyengine/tax_benefit_models/uk/outputs.py b/src/policyengine/tax_benefit_models/uk/outputs.py index ade22ee9..6e3d8742 100644 --- a/src/policyengine/tax_benefit_models/uk/outputs.py +++ b/src/policyengine/tax_benefit_models/uk/outputs.py @@ -1,6 +1,6 @@ """UK-specific output templates.""" -from policyengine.outputs.base import Output +from policyengine.core import Output from policyengine.outputs.aggregate import Aggregate, AggregateType from policyengine.outputs.change_aggregate import ChangeAggregate, ChangeAggregateType from pydantic import ConfigDict diff --git a/src/policyengine/tax_benefit_models/us.py b/src/policyengine/tax_benefit_models/us.py new file mode 100644 index 00000000..074dc9d9 --- /dev/null +++ 
b/src/policyengine/tax_benefit_models/us.py @@ -0,0 +1,20 @@ +"""PolicyEngine US tax-benefit model - imports from us/ module.""" + +from importlib.util import find_spec + +if find_spec("policyengine_us") is not None: + from .us import * + + __all__ = [ + "PolicyEngineUS", + "PolicyEngineUSLatest", + "us_model", + "us_latest", + "general_policy_reform_analysis", + "ProgramStatistics", + ] + + # Rebuild models to resolve forward references + PolicyEngineUSLatest.model_rebuild() +else: + __all__ = [] diff --git a/src/policyengine/tax_benefit_models/us/__init__.py b/src/policyengine/tax_benefit_models/us/__init__.py new file mode 100644 index 00000000..3dd3f21b --- /dev/null +++ b/src/policyengine/tax_benefit_models/us/__init__.py @@ -0,0 +1,19 @@ +"""PolicyEngine US tax-benefit model.""" + +from importlib.util import find_spec + +if find_spec("policyengine_us") is not None: + from .model import PolicyEngineUS, PolicyEngineUSLatest, us_model, us_latest + from .analysis import general_policy_reform_analysis + from .outputs import ProgramStatistics + + __all__ = [ + "PolicyEngineUS", + "PolicyEngineUSLatest", + "us_model", + "us_latest", + "general_policy_reform_analysis", + "ProgramStatistics", + ] +else: + __all__ = [] diff --git a/src/policyengine/tax_benefit_models/us/analysis.py b/src/policyengine/tax_benefit_models/us/analysis.py new file mode 100644 index 00000000..9a0d01f0 --- /dev/null +++ b/src/policyengine/tax_benefit_models/us/analysis.py @@ -0,0 +1,94 @@ +"""General utility functions for US policy reform analysis.""" + +from policyengine.core import Simulation, OutputCollection +from policyengine.outputs.decile_impact import DecileImpact, calculate_decile_impacts +from .outputs import ProgramStatistics +from pydantic import BaseModel +import pandas as pd + + +class PolicyReformAnalysis(BaseModel): + """Complete policy reform analysis result.""" + + decile_impacts: OutputCollection[DecileImpact] + program_statistics: OutputCollection[ProgramStatistics] + + +def general_policy_reform_analysis( + baseline_simulation: Simulation, + reform_simulation: Simulation, +) -> PolicyReformAnalysis: + """Perform comprehensive analysis of a policy reform. 
+ + Returns: + PolicyReformAnalysis containing decile impacts and program statistics + """ + # Decile impact (using household_net_income for US) + decile_impacts = calculate_decile_impacts( + baseline_simulation=baseline_simulation, + reform_simulation=reform_simulation, + income_variable="household_net_income", + ) + + # Major programs to analyse + programs = { + # Federal taxes + "income_tax": {"entity": "tax_unit", "is_tax": True}, + "payroll_tax": {"entity": "person", "is_tax": True}, + # State and local taxes + "state_income_tax": {"entity": "tax_unit", "is_tax": True}, + # Benefits + "snap": {"entity": "spm_unit", "is_tax": False}, + "tanf": {"entity": "spm_unit", "is_tax": False}, + "ssi": {"entity": "person", "is_tax": False}, + "social_security": {"entity": "person", "is_tax": False}, + "medicare": {"entity": "person", "is_tax": False}, + "medicaid": {"entity": "person", "is_tax": False}, + "eitc": {"entity": "tax_unit", "is_tax": False}, + "ctc": {"entity": "tax_unit", "is_tax": False}, + } + + program_statistics = [] + + for program_name, program_info in programs.items(): + entity = program_info["entity"] + is_tax = program_info["is_tax"] + + stats = ProgramStatistics( + baseline_simulation=baseline_simulation, + reform_simulation=reform_simulation, + program_name=program_name, + entity=entity, + is_tax=is_tax, + ) + stats.run() + program_statistics.append(stats) + + # Create DataFrame + program_df = pd.DataFrame([ + { + "baseline_simulation_id": p.baseline_simulation.id, + "reform_simulation_id": p.reform_simulation.id, + "program_name": p.program_name, + "entity": p.entity, + "is_tax": p.is_tax, + "baseline_total": p.baseline_total, + "reform_total": p.reform_total, + "change": p.change, + "baseline_count": p.baseline_count, + "reform_count": p.reform_count, + "winners": p.winners, + "losers": p.losers, + } + for p in program_statistics + ]) + + program_collection = OutputCollection( + outputs=program_statistics, + dataframe=program_df + ) + + return PolicyReformAnalysis( + decile_impacts=decile_impacts, + program_statistics=program_collection + ) diff --git a/src/policyengine/tax_benefit_models/us/model.py b/src/policyengine/tax_benefit_models/us/model.py new file mode 100644 index 00000000..cae04d97 --- /dev/null +++ b/src/policyengine/tax_benefit_models/us/model.py @@ -0,0 +1,105 @@ +from policyengine.core import TaxBenefitModel, TaxBenefitModelVersion, Variable, Parameter, ParameterValue +import datetime +import requests +from importlib.metadata import version +from policyengine.utils import parse_safe_date +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from policyengine.core.simulation import Simulation + + +class PolicyEngineUS(TaxBenefitModel): + id: str = "policyengine-us" + description: str = "The US's open-source dynamic tax and benefit microsimulation model maintained by PolicyEngine." 
+ + +us_model = PolicyEngineUS() + +pkg_version = version("policyengine-us") + +# Get published time from PyPI +response = requests.get("https://pypi.org/pypi/policyengine-us/json") +data = response.json() +upload_time = data["releases"][pkg_version][0]["upload_time_iso_8601"] + + +class PolicyEngineUSLatest(TaxBenefitModelVersion): + model: TaxBenefitModel = us_model + version: str = pkg_version + created_at: datetime.datetime = datetime.datetime.fromisoformat( + upload_time + ) + + def __init__(self, **kwargs: dict): + super().__init__(**kwargs) + from policyengine_us.system import system + from policyengine_core.enums import Enum + + self.id = f"{self.model.id}@{self.version}" + + self.variables = [] + for var_obj in system.variables.values(): + variable = Variable( + id=self.id + "-" + var_obj.name, + name=var_obj.name, + tax_benefit_model_version=self, + entity=var_obj.entity.key, + description=var_obj.documentation, + data_type=var_obj.value_type + if var_obj.value_type is not Enum + else str, + ) + if ( + hasattr(var_obj, "possible_values") + and var_obj.possible_values is not None + ): + variable.possible_values = list( + map( + lambda x: x.name, + var_obj.possible_values._value2member_map_.values(), + ) + ) + self.variables.append(variable) + + self.parameters = [] + from policyengine_core.parameters import Parameter as CoreParameter + + for param_node in system.parameters.get_descendants(): + if isinstance(param_node, CoreParameter): + parameter = Parameter( + id=self.id + "-" + param_node.name, + name=param_node.name, + tax_benefit_model_version=self, + description=param_node.description, + data_type=type( + param_node(2025) + ), + unit=param_node.metadata.get("unit"), + ) + self.parameters.append(parameter) + + for i in range(len(param_node.values_list)): + param_at_instant = param_node.values_list[i] + if i + 1 < len(param_node.values_list): + next_instant = param_node.values_list[i + 1] + else: + next_instant = None + parameter_value = ParameterValue( + parameter=parameter, + start_date=parse_safe_date( + param_at_instant.instant_str + ), + end_date=parse_safe_date(next_instant.instant_str) + if next_instant + else None, + value=param_at_instant.value, + ) + self.parameter_values.append(parameter_value) + + def run(self, simulation: "Simulation") -> "Simulation": + """Run simulation - implementation depends on US dataset structure.""" + raise NotImplementedError("US simulation runner not yet implemented - pending dataset implementation") + + +us_latest = PolicyEngineUSLatest() diff --git a/src/policyengine/tax_benefit_models/us/outputs.py b/src/policyengine/tax_benefit_models/us/outputs.py new file mode 100644 index 00000000..dd1131cf --- /dev/null +++ b/src/policyengine/tax_benefit_models/us/outputs.py @@ -0,0 +1,103 @@ +"""US-specific output templates.""" + +from policyengine.core import Output +from policyengine.outputs.aggregate import Aggregate, AggregateType +from policyengine.outputs.change_aggregate import ChangeAggregate, ChangeAggregateType +from pydantic import ConfigDict +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from policyengine.core.simulation import Simulation + + +class ProgramStatistics(Output): + """Single program's statistics from a policy reform - represents one database row.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + baseline_simulation: "Simulation" + reform_simulation: "Simulation" + program_name: str + entity: str + is_tax: bool = False + + # Results populated by run() + baseline_total: float | None = None + 
reform_total: float | None = None
+    change: float | None = None
+    baseline_count: float | None = None
+    reform_count: float | None = None
+    winners: float | None = None
+    losers: float | None = None
+
+    def run(self):
+        """Calculate statistics for this program."""
+        # Baseline totals
+        baseline_total = Aggregate(
+            simulation=self.baseline_simulation,
+            variable=self.program_name,
+            aggregate_type=AggregateType.SUM,
+            entity=self.entity,
+        )
+        baseline_total.run()
+
+        # Reform totals
+        reform_total = Aggregate(
+            simulation=self.reform_simulation,
+            variable=self.program_name,
+            aggregate_type=AggregateType.SUM,
+            entity=self.entity,
+        )
+        reform_total.run()
+
+        # Count of recipients/payers (baseline)
+        baseline_count = Aggregate(
+            simulation=self.baseline_simulation,
+            variable=self.program_name,
+            aggregate_type=AggregateType.COUNT,
+            entity=self.entity,
+            filter_variable=self.program_name,
+            filter_variable_geq=0.01,
+        )
+        baseline_count.run()
+
+        # Count of recipients/payers (reform)
+        reform_count = Aggregate(
+            simulation=self.reform_simulation,
+            variable=self.program_name,
+            aggregate_type=AggregateType.COUNT,
+            entity=self.entity,
+            filter_variable=self.program_name,
+            filter_variable_geq=0.01,
+        )
+        reform_count.run()
+
+        # Winners and losers. For a benefit, a winner gains at least 0.01;
+        # for a tax, a winner is a unit whose liability falls by at least 0.01.
+        if self.is_tax:
+            winner_filter = {"change_leq": -0.01}
+            loser_filter = {"change_geq": 0.01}
+        else:
+            winner_filter = {"change_geq": 0.01}
+            loser_filter = {"change_leq": -0.01}
+
+        winners = ChangeAggregate(
+            baseline_simulation=self.baseline_simulation,
+            reform_simulation=self.reform_simulation,
+            variable=self.program_name,
+            aggregate_type=ChangeAggregateType.COUNT,
+            entity=self.entity,
+            **winner_filter,
+        )
+        winners.run()
+
+        losers = ChangeAggregate(
+            baseline_simulation=self.baseline_simulation,
+            reform_simulation=self.reform_simulation,
+            variable=self.program_name,
+            aggregate_type=ChangeAggregateType.COUNT,
+            entity=self.entity,
+            **loser_filter,
+        )
+        losers.run()
+
+        # Populate results
+        self.baseline_total = float(baseline_total.result)
+        self.reform_total = float(reform_total.result)
+        self.change = float(reform_total.result - baseline_total.result)
+        self.baseline_count = float(baseline_count.result)
+        self.reform_count = float(reform_count.result)
+        self.winners = float(winners.result)
+        self.losers = float(losers.result)

From 1f272aec9ae9dc4361493dcbac565cfa16e3a548 Mon Sep 17 00:00:00 2001
From: Nikhil Woodruff
Date: Thu, 13 Nov 2025 17:00:12 +0000
Subject: [PATCH 23/35] Add US fixes

---
 src/policyengine/tax_benefit_models/us.py     |   3 +
 .../tax_benefit_models/us/__init__.py         |   3 +
 .../tax_benefit_models/us/datasets.py         | 168 ++++++++++++++++++
 .../tax_benefit_models/us/model.py            | 160 ++++++++++++++++-
 4 files changed, 332 insertions(+), 2 deletions(-)
 create mode 100644 src/policyengine/tax_benefit_models/us/datasets.py

diff --git a/src/policyengine/tax_benefit_models/us.py b/src/policyengine/tax_benefit_models/us.py
index 074dc9d9..50c7b063 100644
--- a/src/policyengine/tax_benefit_models/us.py
+++ b/src/policyengine/tax_benefit_models/us.py
@@ -6,6 +6,8 @@
     from .us import *
 
     __all__ = [
+        "USYearData",
+        "PolicyEngineUSDataset",
         "PolicyEngineUS",
         "PolicyEngineUSLatest",
         "us_model",
@@ -15,6 +17,7 @@
     ]
 
     # Rebuild models to resolve forward references
+    PolicyEngineUSDataset.model_rebuild()
     PolicyEngineUSLatest.model_rebuild()
 else:
     __all__ = []
diff --git a/src/policyengine/tax_benefit_models/us/__init__.py b/src/policyengine/tax_benefit_models/us/__init__.py
index 3dd3f21b..7e8a6a4c 100644
--- a/src/policyengine/tax_benefit_models/us/__init__.py
+++ b/src/policyengine/tax_benefit_models/us/__init__.py
@@ -3,11 +3,14 @@ from 
importlib.util import find_spec if find_spec("policyengine_us") is not None: + from .datasets import USYearData, PolicyEngineUSDataset from .model import PolicyEngineUS, PolicyEngineUSLatest, us_model, us_latest from .analysis import general_policy_reform_analysis from .outputs import ProgramStatistics __all__ = [ + "USYearData", + "PolicyEngineUSDataset", "PolicyEngineUS", "PolicyEngineUSLatest", "us_model", diff --git a/src/policyengine/tax_benefit_models/us/datasets.py b/src/policyengine/tax_benefit_models/us/datasets.py new file mode 100644 index 00000000..b9bad8c5 --- /dev/null +++ b/src/policyengine/tax_benefit_models/us/datasets.py @@ -0,0 +1,168 @@ +from policyengine.core import Dataset +from pydantic import BaseModel, ConfigDict +import pandas as pd +from microdf import MicroDataFrame +from pathlib import Path + + +class USYearData(BaseModel): + """Entity-level data for a single year.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + person: MicroDataFrame + marital_unit: MicroDataFrame + family: MicroDataFrame + spm_unit: MicroDataFrame + tax_unit: MicroDataFrame + household: MicroDataFrame + + def map_to_entity( + self, source_entity: str, target_entity: str, columns: list[str] = None + ) -> MicroDataFrame: + """Map data from source entity to target entity using join keys. + + Args: + source_entity (str): The source entity name. + target_entity (str): The target entity name. + columns (list[str], optional): List of column names to map. If None, maps all columns. + + Returns: + MicroDataFrame: The mapped data at the target entity level. + + Raises: + ValueError: If source or target entity is invalid. + """ + valid_entities = {"person", "marital_unit", "family", "spm_unit", "tax_unit", "household"} + if source_entity not in valid_entities: + raise ValueError( + f"Invalid source entity '{source_entity}'. Must be one of {valid_entities}" + ) + if target_entity not in valid_entities: + raise ValueError( + f"Invalid target entity '{target_entity}'. 
Must be one of {valid_entities}" + ) + + # Get source data + source_df = getattr(self, source_entity) + if columns: + # Select only requested columns (keep join keys) + join_keys = {"person_id", "marital_unit_id", "family_id", "spm_unit_id", "tax_unit_id", "household_id"} + cols_to_keep = list( + set(columns) | (join_keys & set(source_df.columns)) + ) + source_df = source_df[cols_to_keep] + + # Determine weight column for target entity + weight_col_map = { + "person": "person_weight", + "marital_unit": "marital_unit_weight", + "family": "family_weight", + "spm_unit": "spm_unit_weight", + "tax_unit": "tax_unit_weight", + "household": "household_weight", + } + target_weight = weight_col_map[target_entity] + + # Same entity - return as is + if source_entity == target_entity: + return MicroDataFrame( + pd.DataFrame(source_df), weights=target_weight + ) + + # Map to different entity + target_df = getattr(self, target_entity) + + # Direct mapping if join key exists in source + target_key = f"{target_entity}_id" + if target_key in pd.DataFrame(source_df).columns: + result = pd.DataFrame(target_df).merge( + pd.DataFrame(source_df), on=target_key, how="left" + ) + return MicroDataFrame(result, weights=target_weight) + + # For more complex mappings, go through person table + if source_entity != "person" and target_entity != "person": + # Get person link table with both entity IDs + person_df = pd.DataFrame(self.person) + source_key = f"{source_entity}_id" + + # Link source -> person -> target + if source_key in person_df.columns and target_key in person_df.columns: + person_link = person_df[[source_key, target_key]].drop_duplicates() + source_with_target = pd.DataFrame(source_df).merge( + person_link, on=source_key, how="left" + ) + result = pd.DataFrame(target_df).merge( + source_with_target, on=target_key, how="left" + ) + return MicroDataFrame(result, weights=target_weight) + + raise ValueError( + f"Unsupported mapping from {source_entity} to {target_entity}" + ) + + +class PolicyEngineUSDataset(Dataset): + """US dataset with multi-year entity-level data.""" + + data: USYearData | None = None + + def __init__(self, **kwargs: dict): + super().__init__(**kwargs) + + # Make sure we are synchronised between in-memory and storage, at least on initialisation + if "data" in kwargs: + self.save() + elif "filepath" in kwargs: + self.load() + + def save(self) -> None: + """Save dataset to HDF5 file.""" + filepath = Path(self.filepath) + if not filepath.parent.exists(): + filepath.parent.mkdir(parents=True, exist_ok=True) + with pd.HDFStore(filepath, mode="w") as store: + store["person"] = pd.DataFrame(self.data.person) + store["marital_unit"] = pd.DataFrame(self.data.marital_unit) + store["family"] = pd.DataFrame(self.data.family) + store["spm_unit"] = pd.DataFrame(self.data.spm_unit) + store["tax_unit"] = pd.DataFrame(self.data.tax_unit) + store["household"] = pd.DataFrame(self.data.household) + + def load(self) -> None: + """Load dataset from HDF5 file into this instance.""" + filepath = self.filepath + with pd.HDFStore(filepath, mode="r") as store: + self.data = USYearData( + person=MicroDataFrame( + store["person"], weights="person_weight" + ), + marital_unit=MicroDataFrame( + store["marital_unit"], weights="marital_unit_weight" + ), + family=MicroDataFrame( + store["family"], weights="family_weight" + ), + spm_unit=MicroDataFrame( + store["spm_unit"], weights="spm_unit_weight" + ), + tax_unit=MicroDataFrame( + store["tax_unit"], weights="tax_unit_weight" + ), + household=MicroDataFrame( + 
store["household"], weights="household_weight" + ), + ) + + def __repr__(self) -> str: + if self.data is None: + return f"" + else: + n_people = len(self.data.person) + n_marital_units = len(self.data.marital_unit) + n_families = len(self.data.family) + n_spm_units = len(self.data.spm_unit) + n_tax_units = len(self.data.tax_unit) + n_households = len(self.data.household) + return f"" diff --git a/src/policyengine/tax_benefit_models/us/model.py b/src/policyengine/tax_benefit_models/us/model.py index cae04d97..676ade1b 100644 --- a/src/policyengine/tax_benefit_models/us/model.py +++ b/src/policyengine/tax_benefit_models/us/model.py @@ -3,6 +3,10 @@ import requests from importlib.metadata import version from policyengine.utils import parse_safe_date +import pandas as pd +from microdf import MicroDataFrame +from pathlib import Path +from .datasets import PolicyEngineUSDataset, USYearData from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -98,8 +102,160 @@ def __init__(self, **kwargs: dict): self.parameter_values.append(parameter_value) def run(self, simulation: "Simulation") -> "Simulation": - """Run simulation - implementation depends on US dataset structure.""" - raise NotImplementedError("US simulation runner not yet implemented - pending dataset implementation") + from policyengine_us import Microsimulation + from policyengine.utils.parametric_reforms import simulation_modifier_from_parameter_values + + assert isinstance(simulation.dataset, PolicyEngineUSDataset) + + dataset = simulation.dataset + dataset.load() + microsim = Microsimulation(dataset=None) + + if ( + simulation.policy + and simulation.policy.simulation_modifier is not None + ): + simulation.policy.simulation_modifier(microsim) + elif simulation.policy: + modifier = simulation_modifier_from_parameter_values( + simulation.policy.parameter_values + ) + modifier(microsim) + + if ( + simulation.dynamic + and simulation.dynamic.simulation_modifier is not None + ): + simulation.dynamic.simulation_modifier(microsim) + elif simulation.dynamic: + modifier = simulation_modifier_from_parameter_values( + simulation.dynamic.parameter_values + ) + modifier(microsim) + + # Allow custom variable selection, or use defaults + if simulation.variables is not None: + entity_variables = simulation.variables + else: + # Default comprehensive variable set + entity_variables = { + "person": [ + # IDs and weights + "person_id", + "marital_unit_id", + "family_id", + "spm_unit_id", + "tax_unit_id", + "household_id", + "person_weight", + # Demographics + "age", + "gender", + "is_adult", + "is_child", + # Income + "employment_income", + "self_employment_income", + "pension_income", + "social_security", + "ssi", + # Benefits + "snap", + "tanf", + "medicare", + "medicaid", + # Tax + "payroll_tax", + ], + "marital_unit": [ + "marital_unit_id", + "marital_unit_weight", + ], + "family": [ + "family_id", + "family_weight", + ], + "spm_unit": [ + "spm_unit_id", + "spm_unit_weight", + "snap", + "tanf", + "spm_unit_net_income", + ], + "tax_unit": [ + "tax_unit_id", + "tax_unit_weight", + "income_tax", + "payroll_tax", + "state_income_tax", + "eitc", + "ctc", + "adjusted_gross_income", + ], + "household": [ + "household_id", + "household_weight", + "household_net_income", + "household_benefits", + "household_tax", + "household_market_income", + ], + } + + data = { + "person": pd.DataFrame(), + "marital_unit": pd.DataFrame(), + "family": pd.DataFrame(), + "spm_unit": pd.DataFrame(), + "tax_unit": pd.DataFrame(), + "household": pd.DataFrame(), + } + + for entity, 
+        for entity, variables in entity_variables.items():
+            for var in variables:
+                data[entity][var] = microsim.calculate(
+                    var, period=simulation.dataset.year, map_to=entity
+                ).values
+
+        data["person"] = MicroDataFrame(
+            data["person"], weights="person_weight"
+        )
+        data["marital_unit"] = MicroDataFrame(
+            data["marital_unit"], weights="marital_unit_weight"
+        )
+        data["family"] = MicroDataFrame(
+            data["family"], weights="family_weight"
+        )
+        data["spm_unit"] = MicroDataFrame(
+            data["spm_unit"], weights="spm_unit_weight"
+        )
+        data["tax_unit"] = MicroDataFrame(
+            data["tax_unit"], weights="tax_unit_weight"
+        )
+        data["household"] = MicroDataFrame(
+            data["household"], weights="household_weight"
+        )
+
+        simulation.output_dataset = PolicyEngineUSDataset(
+            name=dataset.name,
+            description=dataset.description,
+            filepath=str(
+                Path(simulation.dataset.filepath).parent
+                / (simulation.id + ".h5")
+            ),
+            year=simulation.dataset.year,
+            is_output_dataset=True,
+            data=USYearData(
+                person=data["person"],
+                marital_unit=data["marital_unit"],
+                family=data["family"],
+                spm_unit=data["spm_unit"],
+                tax_unit=data["tax_unit"],
+                household=data["household"],
+            ),
+        )
+
+        simulation.output_dataset.save()
+
+        return simulation
 
 
 us_latest = PolicyEngineUSLatest()

From e8edb7facb2e908b68989bdd2ed1143303c1c2d5 Mon Sep 17 00:00:00 2001
From: Nikhil Woodruff
Date: Sun, 16 Nov 2025 18:49:41 +0000
Subject: [PATCH 24/35] Add household analysis example

---
 examples/employment_income_variation.py       | 279 ++++++++
 examples/income_bands.py                      |   4 +-
 examples/policy_change.py                     |  10 +-
 src/policyengine/core/__init__.py             |   1 +
 src/policyengine/core/dataset.py              |   2 +-
 src/policyengine/core/output.py               |   2 +-
 src/policyengine/outputs/__init__.py          |  10 +-
 src/policyengine/outputs/aggregate.py         |  20 +-
 src/policyengine/outputs/change_aggregate.py  |  95 ++-
 src/policyengine/outputs/decile_impact.py     |  66 +-
 src/policyengine/tax_benefit_models/uk.py     |   4 +
 .../tax_benefit_models/uk/__init__.py         |   8 +
 .../tax_benefit_models/uk/analysis.py         |  46 +-
 .../tax_benefit_models/uk/datasets.py         |  51 +-
 .../tax_benefit_models/uk/model.py            |  13 +-
 .../tax_benefit_models/uk/outputs.py          |   5 +-
 .../tax_benefit_models/us/__init__.py         |   7 +-
 .../tax_benefit_models/us/analysis.py         |  47 +-
 .../tax_benefit_models/us/datasets.py         |  27 +-
 .../tax_benefit_models/us/model.py            |  16 +-
 .../tax_benefit_models/us/outputs.py          |   5 +-
 src/policyengine/utils/parametric_reforms.py  |  16 +-
 tests/test_change_aggregate.py                | 615 +++++++++++------
 23 files changed, 960 insertions(+), 389 deletions(-)
 create mode 100644 examples/employment_income_variation.py

diff --git a/examples/employment_income_variation.py b/examples/employment_income_variation.py
new file mode 100644
index 00000000..7fa02b7a
--- /dev/null
+++ b/examples/employment_income_variation.py
@@ -0,0 +1,279 @@
+"""Example: Vary employment income and plot HBAI household net income.
+
+This script demonstrates:
+1. Creating a custom dataset with a single household template
+2. Varying employment income from £0 to £100k
+3. Running a single simulation for all variations
+4. Using Aggregate with filters to extract results by employment income
+5. 
Visualising the relationship between employment income and net income + +Run: python examples/employment_income_variation.py +""" + +import pandas as pd +import tempfile +from pathlib import Path +import plotly.graph_objects as go +from microdf import MicroDataFrame +from policyengine.core import Simulation +from policyengine.tax_benefit_models.uk import ( + PolicyEngineUKDataset, + UKYearData, + uk_latest, +) +from policyengine.outputs.aggregate import Aggregate, AggregateType + + +def create_dataset_with_varied_employment_income( + employment_incomes: list[float], year: int = 2026 +) -> PolicyEngineUKDataset: + """Create a dataset with one household template, varied by employment income. + + Each household is a single adult with varying employment income. + Everything else about the household is kept constant. + """ + n_households = len(employment_incomes) + + # Create person data - one adult per household with varying employment income + person_data = { + "person_id": list(range(n_households)), + "person_benunit_id": list(range(n_households)), # Link to benunit + "person_household_id": list(range(n_households)), # Link to household + "age": [35] * n_households, # Single adult, age 35 + "employment_income": employment_incomes, + "person_weight": [1.0] * n_households, + } + + # Create benunit data - one per household + benunit_data = { + "benunit_id": list(range(n_households)), + "benunit_weight": [1.0] * n_households, + } + + # Create household data - one per employment income level + household_data = { + "household_id": list(range(n_households)), + "household_weight": [1.0] * n_households, + "region": ["LONDON"] * n_households, # Required by policyengine-uk + "council_tax": [0.0] * n_households, # Simplified - no council tax + "rent": [0.0] * n_households, # Simplified - no rent + "tenure_type": ["RENT_PRIVATELY"] + * n_households, # Required for uprating + } + + # Create MicroDataFrames + person_df = MicroDataFrame( + pd.DataFrame(person_data), weights="person_weight" + ) + benunit_df = MicroDataFrame( + pd.DataFrame(benunit_data), weights="benunit_weight" + ) + household_df = MicroDataFrame( + pd.DataFrame(household_data), weights="household_weight" + ) + + # Create temporary file + tmpdir = tempfile.mkdtemp() + filepath = str(Path(tmpdir) / "employment_income_variation.h5") + + # Create dataset + dataset = PolicyEngineUKDataset( + name="Employment income variation", + description="Single adult household with varying employment income", + filepath=filepath, + year=year, + data=UKYearData( + person=person_df, + benunit=benunit_df, + household=household_df, + ), + ) + + return dataset + + +def run_simulation(dataset: PolicyEngineUKDataset) -> Simulation: + """Run a single simulation for all employment income variations.""" + simulation = Simulation( + dataset=dataset, + tax_benefit_model_version=uk_latest, + ) + simulation.run() + return simulation + + +def extract_results_by_employment_income( + simulation: Simulation, employment_incomes: list[float] +) -> dict: + """Extract HBAI household net income and components for each employment income level. + + Uses Aggregate with filters to extract specific households. 
+ """ + hbai_net_income = [] + household_benefits = [] + household_tax = [] + employment_income_hh = [] + + for emp_income in employment_incomes: + # Get HBAI household net income + agg = Aggregate( + simulation=simulation, + variable="hbai_household_net_income", + aggregate_type=AggregateType.MEAN, + filter_variable="employment_income", + filter_variable_eq=emp_income, + entity="household", + ) + agg.run() + hbai_net_income.append(agg.result) + + # Get household benefits + agg = Aggregate( + simulation=simulation, + variable="household_benefits", + aggregate_type=AggregateType.MEAN, + filter_variable="employment_income", + filter_variable_eq=emp_income, + entity="household", + ) + agg.run() + household_benefits.append(agg.result) + + # Get household tax + agg = Aggregate( + simulation=simulation, + variable="household_tax", + aggregate_type=AggregateType.MEAN, + filter_variable="employment_income", + filter_variable_eq=emp_income, + entity="household", + ) + agg.run() + household_tax.append(agg.result) + + # Get employment income at household level + agg = Aggregate( + simulation=simulation, + variable="employment_income", + aggregate_type=AggregateType.MEAN, + filter_variable="employment_income", + filter_variable_eq=emp_income, + entity="household", + ) + agg.run() + employment_income_hh.append(agg.result) + + return { + "employment_income": employment_incomes, + "hbai_household_net_income": hbai_net_income, + "household_benefits": household_benefits, + "household_tax": household_tax, + "employment_income_hh": employment_income_hh, + } + + +def visualise_results(results: dict) -> None: + """Create a line chart showing HBAI household net income and components.""" + fig = go.Figure() + + # Main HBAI net income line + fig.add_trace( + go.Scatter( + x=results["employment_income"], + y=results["hbai_household_net_income"], + mode="lines+markers", + name="HBAI household net income", + line=dict(color="darkblue", width=3), + marker=dict(size=5), + ) + ) + + # Employment income (gross) + fig.add_trace( + go.Scatter( + x=results["employment_income"], + y=results["employment_income_hh"], + mode="lines", + name="Employment income (gross)", + line=dict(color="green", width=2, dash="dot"), + ) + ) + + # Household benefits + fig.add_trace( + go.Scatter( + x=results["employment_income"], + y=results["household_benefits"], + mode="lines", + name="Household benefits", + line=dict(color="orange", width=2), + ) + ) + + # Household tax (negative for visual clarity) + fig.add_trace( + go.Scatter( + x=results["employment_income"], + y=[-t for t in results["household_tax"]], + mode="lines", + name="Household tax (negative)", + line=dict(color="red", width=2), + ) + ) + + fig.update_layout( + title="HBAI household net income and components by employment income", + xaxis_title="Employment income (£)", + yaxis_title="Amount (£)", + height=600, + width=1000, + showlegend=True, + legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01), + ) + + fig.show() + + +def main(): + """Main execution function.""" + # Create employment income range from £0 to £100k + # Using smaller intervals at lower incomes where the relationship is more interesting + employment_incomes = ( + list(range(0, 20000, 1000)) # £0 to £20k in £1k steps + + list(range(20000, 50000, 2500)) # £20k to £50k in £2.5k steps + + list(range(50000, 100001, 5000)) # £50k to £100k in £5k steps + ) + + print( + f"Creating dataset with {len(employment_incomes)} employment income variations..." 
+ ) + dataset = create_dataset_with_varied_employment_income(employment_incomes) + + print("Running simulation (single run for all variations)...") + simulation = run_simulation(dataset) + + print("Extracting results using aggregate filters...") + results = extract_results_by_employment_income( + simulation, employment_incomes + ) + + print("\nSample results:") + print( + f"Employment income £0: HBAI net income £{results['hbai_household_net_income'][0]:,.0f}" + ) + print( + f"Employment income £25k: HBAI net income £{results['hbai_household_net_income'][employment_incomes.index(25000)]:,.0f}" + ) + print( + f"Employment income £50k: HBAI net income £{results['hbai_household_net_income'][employment_incomes.index(50000)]:,.0f}" + ) + print( + f"Employment income £100k: HBAI net income £{results['hbai_household_net_income'][-1]:,.0f}" + ) + + print("\nGenerating visualisation...") + visualise_results(results) + + +if __name__ == "__main__": + main() diff --git a/examples/income_bands.py b/examples/income_bands.py index cd5819bd..06f22598 100644 --- a/examples/income_bands.py +++ b/examples/income_bands.py @@ -176,7 +176,9 @@ def main(): print(f"Total net income: £{total_net_income:.1f}bn") print(f"Total tax: £{total_tax:.1f}bn") print(f"Total households: {total_households:.1f}m") - print(f"Average effective tax rate: {total_tax / (total_net_income + total_tax) * 100:.1f}%") + print( + f"Average effective tax rate: {total_tax / (total_net_income + total_tax) * 100:.1f}%" + ) print("\nGenerating visualisations...") visualise_results(results) diff --git a/examples/policy_change.py b/examples/policy_change.py index 243515f7..574f37cd 100644 --- a/examples/policy_change.py +++ b/examples/policy_change.py @@ -268,11 +268,15 @@ def print_summary(overall: dict, decile: dict, reform_name: str) -> None: print(f" Losers: {overall['losers']:.2f}m households") print(f" No change: {overall['no_change']:.2f}m households") print(f"\nFinancial impact:") - print(f" Net income change: £{overall['total_change']:.2f}bn (negative = loss)") + print( + f" Net income change: £{overall['total_change']:.2f}bn (negative = loss)" + ) print(f" Tax revenue change: £{overall['tax_revenue_change']:.2f}bn") print(f"\nImpact by income decile:") - for i, label in enumerate(decile['labels']): - print(f" {label}: {decile['losers'][i]:.2f}m losers, avg change £{decile['avg_loss'][i]:.0f}") + for i, label in enumerate(decile["labels"]): + print( + f" {label}: {decile['losers'][i]:.2f}m losers, avg change £{decile['avg_loss'][i]:.0f}" + ) print("=" * 60) diff --git a/src/policyengine/core/__init__.py b/src/policyengine/core/__init__.py index f94a0c3b..58372d57 100644 --- a/src/policyengine/core/__init__.py +++ b/src/policyengine/core/__init__.py @@ -11,6 +11,7 @@ from .output import Output, OutputCollection # Rebuild models to resolve forward references +Dataset.model_rebuild() TaxBenefitModelVersion.model_rebuild() Variable.model_rebuild() Parameter.model_rebuild() diff --git a/src/policyengine/core/dataset.py b/src/policyengine/core/dataset.py index 34997d01..5ae76fa4 100644 --- a/src/policyengine/core/dataset.py +++ b/src/policyengine/core/dataset.py @@ -29,7 +29,7 @@ class MyDataset(Dataset): dataset_version: DatasetVersion | None = None filepath: str is_output_dataset: bool = False - tax_benefit_model: TaxBenefitModel = None + tax_benefit_model: TaxBenefitModel | None = None year: int data: BaseModel | None = None diff --git a/src/policyengine/core/output.py b/src/policyengine/core/output.py index 26418d69..874b694f 100644 --- 
a/src/policyengine/core/output.py +++ b/src/policyengine/core/output.py @@ -2,7 +2,7 @@ import pandas as pd from typing import Generic, TypeVar -T = TypeVar('T', bound='Output') +T = TypeVar("T", bound="Output") class Output(BaseModel): diff --git a/src/policyengine/outputs/__init__.py b/src/policyengine/outputs/__init__.py index 3a873415..8997578d 100644 --- a/src/policyengine/outputs/__init__.py +++ b/src/policyengine/outputs/__init__.py @@ -1,7 +1,13 @@ from policyengine.core import Output, OutputCollection from policyengine.outputs.aggregate import Aggregate, AggregateType -from policyengine.outputs.change_aggregate import ChangeAggregate, ChangeAggregateType -from policyengine.outputs.decile_impact import DecileImpact, calculate_decile_impacts +from policyengine.outputs.change_aggregate import ( + ChangeAggregate, + ChangeAggregateType, +) +from policyengine.outputs.decile_impact import ( + DecileImpact, + calculate_decile_impacts, +) __all__ = [ "Output", diff --git a/src/policyengine/outputs/aggregate.py b/src/policyengine/outputs/aggregate.py index 120b125a..29f18138 100644 --- a/src/policyengine/outputs/aggregate.py +++ b/src/policyengine/outputs/aggregate.py @@ -22,10 +22,16 @@ class Aggregate(Output): filter_variable_describes_quantiles: bool = False # Convenient quantile specification (alternative to describes_quantiles) - quantile: int | None = None # Number of quantiles (e.g., 10 for deciles, 5 for quintiles) + quantile: int | None = ( + None # Number of quantiles (e.g., 10 for deciles, 5 for quintiles) + ) quantile_eq: int | None = None # Exact quantile (e.g., 3 for 3rd decile) - quantile_leq: int | None = None # Maximum quantile (e.g., 5 for bottom 5 deciles) - quantile_geq: int | None = None # Minimum quantile (e.g., 9 for top 2 deciles) + quantile_leq: int | None = ( + None # Maximum quantile (e.g., 5 for bottom 5 deciles) + ) + quantile_geq: int | None = ( + None # Minimum quantile (e.g., 9 for top 2 deciles) + ) result: Any | None = None @@ -35,12 +41,16 @@ def run(self): self.filter_variable_describes_quantiles = True if self.quantile_eq is not None: # For a specific quantile, filter between (quantile-1)/n and quantile/n - self.filter_variable_geq = (self.quantile_eq - 1) / self.quantile + self.filter_variable_geq = ( + self.quantile_eq - 1 + ) / self.quantile self.filter_variable_leq = self.quantile_eq / self.quantile elif self.quantile_leq is not None: self.filter_variable_leq = self.quantile_leq / self.quantile elif self.quantile_geq is not None: - self.filter_variable_geq = (self.quantile_geq - 1) / self.quantile + self.filter_variable_geq = ( + self.quantile_geq - 1 + ) / self.quantile # Get variable object var_obj = next( diff --git a/src/policyengine/outputs/change_aggregate.py b/src/policyengine/outputs/change_aggregate.py index f3ab7e78..359ead57 100644 --- a/src/policyengine/outputs/change_aggregate.py +++ b/src/policyengine/outputs/change_aggregate.py @@ -19,12 +19,12 @@ class ChangeAggregate(Output): # Filter by absolute change change_geq: float | None = None # Change >= value (e.g., gain >= 500) change_leq: float | None = None # Change <= value (e.g., loss <= -500) - change_eq: float | None = None # Change == value + change_eq: float | None = None # Change == value # Filter by relative change (as decimal, e.g., 0.05 = 5%) relative_change_geq: float | None = None # Relative change >= value relative_change_leq: float | None = None # Relative change <= value - relative_change_eq: float | None = None # Relative change == value + relative_change_eq: float | 
None = None # Relative change == value # Filter by another variable (e.g., only count people with age >= 30) filter_variable: str | None = None @@ -34,10 +34,16 @@ class ChangeAggregate(Output): filter_variable_describes_quantiles: bool = False # Convenient quantile specification (alternative to describes_quantiles) - quantile: int | None = None # Number of quantiles (e.g., 10 for deciles, 5 for quintiles) + quantile: int | None = ( + None # Number of quantiles (e.g., 10 for deciles, 5 for quintiles) + ) quantile_eq: int | None = None # Exact quantile (e.g., 3 for 3rd decile) - quantile_leq: int | None = None # Maximum quantile (e.g., 5 for bottom 5 deciles) - quantile_geq: int | None = None # Minimum quantile (e.g., 9 for top 2 deciles) + quantile_leq: int | None = ( + None # Maximum quantile (e.g., 5 for bottom 5 deciles) + ) + quantile_geq: int | None = ( + None # Minimum quantile (e.g., 9 for top 2 deciles) + ) result: Any | None = None @@ -47,33 +53,46 @@ def run(self): self.filter_variable_describes_quantiles = True if self.quantile_eq is not None: # For a specific quantile, filter between (quantile-1)/n and quantile/n - self.filter_variable_geq = (self.quantile_eq - 1) / self.quantile + self.filter_variable_geq = ( + self.quantile_eq - 1 + ) / self.quantile self.filter_variable_leq = self.quantile_eq / self.quantile elif self.quantile_leq is not None: self.filter_variable_leq = self.quantile_leq / self.quantile elif self.quantile_geq is not None: - self.filter_variable_geq = (self.quantile_geq - 1) / self.quantile + self.filter_variable_geq = ( + self.quantile_geq - 1 + ) / self.quantile # Get variable object var_obj = next( - v for v in self.baseline_simulation.tax_benefit_model_version.variables + v + for v in self.baseline_simulation.tax_benefit_model_version.variables if v.name == self.variable ) # Get the target entity data target_entity = self.entity or var_obj.entity - baseline_data = getattr(self.baseline_simulation.output_dataset.data, target_entity) - reform_data = getattr(self.reform_simulation.output_dataset.data, target_entity) + baseline_data = getattr( + self.baseline_simulation.output_dataset.data, target_entity + ) + reform_data = getattr( + self.reform_simulation.output_dataset.data, target_entity + ) # Map variable to target entity if needed if var_obj.entity != target_entity: - baseline_mapped = self.baseline_simulation.output_dataset.data.map_to_entity( - var_obj.entity, target_entity + baseline_mapped = ( + self.baseline_simulation.output_dataset.data.map_to_entity( + var_obj.entity, target_entity + ) ) baseline_series = baseline_mapped[self.variable] - reform_mapped = self.reform_simulation.output_dataset.data.map_to_entity( - var_obj.entity, target_entity + reform_mapped = ( + self.reform_simulation.output_dataset.data.map_to_entity( + var_obj.entity, target_entity + ) ) reform_series = reform_mapped[self.variable] else: @@ -86,39 +105,45 @@ def run(self): # Calculate relative change (handling division by zero) # Where baseline is 0, relative change is undefined; we'll mask these out if relative filters are used import numpy as np - with np.errstate(divide='ignore', invalid='ignore'): + + with np.errstate(divide="ignore", invalid="ignore"): relative_change_series = change_series / baseline_series - relative_change_series = relative_change_series.replace([np.inf, -np.inf], np.nan) + relative_change_series = relative_change_series.replace( + [np.inf, -np.inf], np.nan + ) # Start with all rows mask = baseline_series.notna() # Apply absolute change filters if 
self.change_eq is not None: - mask &= (change_series == self.change_eq) + mask &= change_series == self.change_eq if self.change_leq is not None: - mask &= (change_series <= self.change_leq) + mask &= change_series <= self.change_leq if self.change_geq is not None: - mask &= (change_series >= self.change_geq) + mask &= change_series >= self.change_geq # Apply relative change filters if self.relative_change_eq is not None: - mask &= (relative_change_series == self.relative_change_eq) + mask &= relative_change_series == self.relative_change_eq if self.relative_change_leq is not None: - mask &= (relative_change_series <= self.relative_change_leq) + mask &= relative_change_series <= self.relative_change_leq if self.relative_change_geq is not None: - mask &= (relative_change_series >= self.relative_change_geq) + mask &= relative_change_series >= self.relative_change_geq # Apply filter_variable filters if self.filter_variable is not None: filter_var_obj = next( - v for v in self.baseline_simulation.tax_benefit_model_version.variables + v + for v in self.baseline_simulation.tax_benefit_model_version.variables if v.name == self.filter_variable ) if filter_var_obj.entity != target_entity: - filter_mapped = self.baseline_simulation.output_dataset.data.map_to_entity( - filter_var_obj.entity, target_entity + filter_mapped = ( + self.baseline_simulation.output_dataset.data.map_to_entity( + filter_var_obj.entity, target_entity + ) ) filter_series = filter_mapped[self.filter_variable] else: @@ -127,20 +152,24 @@ def run(self): if self.filter_variable_describes_quantiles: if self.filter_variable_eq is not None: threshold = filter_series.quantile(self.filter_variable_eq) - mask &= (filter_series <= threshold) + mask &= filter_series <= threshold if self.filter_variable_leq is not None: - threshold = filter_series.quantile(self.filter_variable_leq) - mask &= (filter_series <= threshold) + threshold = filter_series.quantile( + self.filter_variable_leq + ) + mask &= filter_series <= threshold if self.filter_variable_geq is not None: - threshold = filter_series.quantile(self.filter_variable_geq) - mask &= (filter_series >= threshold) + threshold = filter_series.quantile( + self.filter_variable_geq + ) + mask &= filter_series >= threshold else: if self.filter_variable_eq is not None: - mask &= (filter_series == self.filter_variable_eq) + mask &= filter_series == self.filter_variable_eq if self.filter_variable_leq is not None: - mask &= (filter_series <= self.filter_variable_leq) + mask &= filter_series <= self.filter_variable_leq if self.filter_variable_geq is not None: - mask &= (filter_series >= self.filter_variable_geq) + mask &= filter_series >= self.filter_variable_geq # Apply mask to get filtered data filtered_change = change_series[mask] diff --git a/src/policyengine/outputs/decile_impact.py b/src/policyengine/outputs/decile_impact.py index b1a16359..f2e7837f 100644 --- a/src/policyengine/outputs/decile_impact.py +++ b/src/policyengine/outputs/decile_impact.py @@ -37,18 +37,26 @@ def run(self): target_entity = self.entity or var_obj.entity # Get data from both simulations - baseline_data = getattr(self.baseline_simulation.output_dataset.data, target_entity) - reform_data = getattr(self.reform_simulation.output_dataset.data, target_entity) + baseline_data = getattr( + self.baseline_simulation.output_dataset.data, target_entity + ) + reform_data = getattr( + self.reform_simulation.output_dataset.data, target_entity + ) # Map income variable to target entity if needed if var_obj.entity != target_entity: - 
baseline_mapped = self.baseline_simulation.output_dataset.data.map_to_entity( - var_obj.entity, target_entity + baseline_mapped = ( + self.baseline_simulation.output_dataset.data.map_to_entity( + var_obj.entity, target_entity + ) ) baseline_income = baseline_mapped[self.income_variable] - reform_mapped = self.reform_simulation.output_dataset.data.map_to_entity( - var_obj.entity, target_entity + reform_mapped = ( + self.reform_simulation.output_dataset.data.map_to_entity( + var_obj.entity, target_entity + ) ) reform_income = reform_mapped[self.income_variable] else: @@ -56,14 +64,22 @@ def run(self): reform_income = reform_data[self.income_variable] # Calculate deciles based on baseline income - decile_series = pd.qcut(baseline_income, self.quantiles, labels=False, duplicates='drop') + 1 + decile_series = ( + pd.qcut( + baseline_income, + self.quantiles, + labels=False, + duplicates="drop", + ) + + 1 + ) # Calculate changes absolute_change = reform_income - baseline_income relative_change = (absolute_change / baseline_income) * 100 # Filter to this decile - mask = (decile_series == self.decile) + mask = decile_series == self.decile # Populate results self.baseline_mean = float(baseline_income[mask].mean()) @@ -101,21 +117,23 @@ def calculate_decile_impacts( results.append(impact) # Create DataFrame - df = pd.DataFrame([ - { - "baseline_simulation_id": r.baseline_simulation.id, - "reform_simulation_id": r.reform_simulation.id, - "income_variable": r.income_variable, - "decile": r.decile, - "baseline_mean": r.baseline_mean, - "reform_mean": r.reform_mean, - "absolute_change": r.absolute_change, - "relative_change": r.relative_change, - "count_better_off": r.count_better_off, - "count_worse_off": r.count_worse_off, - "count_no_change": r.count_no_change, - } - for r in results - ]) + df = pd.DataFrame( + [ + { + "baseline_simulation_id": r.baseline_simulation.id, + "reform_simulation_id": r.reform_simulation.id, + "income_variable": r.income_variable, + "decile": r.decile, + "baseline_mean": r.baseline_mean, + "reform_mean": r.reform_mean, + "absolute_change": r.absolute_change, + "relative_change": r.relative_change, + "count_better_off": r.count_better_off, + "count_worse_off": r.count_worse_off, + "count_no_change": r.count_no_change, + } + for r in results + ] + ) return OutputCollection(outputs=results, dataframe=df) diff --git a/src/policyengine/tax_benefit_models/uk.py b/src/policyengine/tax_benefit_models/uk.py index a070033e..0056d54f 100644 --- a/src/policyengine/tax_benefit_models/uk.py +++ b/src/policyengine/tax_benefit_models/uk.py @@ -15,5 +15,9 @@ ] # Rebuild models to resolve forward references +from policyengine.core import Dataset + +Dataset.model_rebuild() +UKYearData.model_rebuild() PolicyEngineUKDataset.model_rebuild() PolicyEngineUKLatest.model_rebuild() diff --git a/src/policyengine/tax_benefit_models/uk/__init__.py b/src/policyengine/tax_benefit_models/uk/__init__.py index 91f58000..f77f3988 100644 --- a/src/policyengine/tax_benefit_models/uk/__init__.py +++ b/src/policyengine/tax_benefit_models/uk/__init__.py @@ -16,3 +16,11 @@ "general_policy_reform_analysis", "ProgrammeStatistics", ] + +# Rebuild models to resolve forward references +from policyengine.core import Dataset + +Dataset.model_rebuild() +UKYearData.model_rebuild() +PolicyEngineUKDataset.model_rebuild() +PolicyEngineUKLatest.model_rebuild() diff --git a/src/policyengine/tax_benefit_models/uk/analysis.py b/src/policyengine/tax_benefit_models/uk/analysis.py index 18ab1f5f..9573cc52 100644 --- 
a/src/policyengine/tax_benefit_models/uk/analysis.py +++ b/src/policyengine/tax_benefit_models/uk/analysis.py @@ -1,7 +1,10 @@ """General utility functions for UK policy reform analysis.""" from policyengine.core import Simulation, OutputCollection -from policyengine.outputs.decile_impact import DecileImpact, calculate_decile_impacts +from policyengine.outputs.decile_impact import ( + DecileImpact, + calculate_decile_impacts, +) from .outputs import ProgrammeStatistics from pydantic import BaseModel import pandas as pd @@ -62,30 +65,31 @@ def general_policy_reform_analysis( programme_statistics.append(stats) # Create DataFrame - programme_df = pd.DataFrame([ - { - "baseline_simulation_id": p.baseline_simulation.id, - "reform_simulation_id": p.reform_simulation.id, - "programme_name": p.programme_name, - "entity": p.entity, - "is_tax": p.is_tax, - "baseline_total": p.baseline_total, - "reform_total": p.reform_total, - "change": p.change, - "baseline_count": p.baseline_count, - "reform_count": p.reform_count, - "winners": p.winners, - "losers": p.losers, - } - for p in programme_statistics - ]) + programme_df = pd.DataFrame( + [ + { + "baseline_simulation_id": p.baseline_simulation.id, + "reform_simulation_id": p.reform_simulation.id, + "programme_name": p.programme_name, + "entity": p.entity, + "is_tax": p.is_tax, + "baseline_total": p.baseline_total, + "reform_total": p.reform_total, + "change": p.change, + "baseline_count": p.baseline_count, + "reform_count": p.reform_count, + "winners": p.winners, + "losers": p.losers, + } + for p in programme_statistics + ] + ) programme_collection = OutputCollection( - outputs=programme_statistics, - dataframe=programme_df + outputs=programme_statistics, dataframe=programme_df ) return PolicyReformAnalysis( decile_impacts=decile_impacts, - programme_statistics=programme_collection + programme_statistics=programme_collection, ) diff --git a/src/policyengine/tax_benefit_models/uk/datasets.py b/src/policyengine/tax_benefit_models/uk/datasets.py index 073d3a95..309ef4bb 100644 --- a/src/policyengine/tax_benefit_models/uk/datasets.py +++ b/src/policyengine/tax_benefit_models/uk/datasets.py @@ -134,14 +134,17 @@ class PolicyEngineUKDataset(Dataset): data: UKYearData | None = None - def __init__(self, **kwargs: dict): - super().__init__(**kwargs) - + def model_post_init(self, __context): + """Called after Pydantic initialization.""" # Make sure we are synchronised between in-memory and storage, at least on initialisation - if "data" in kwargs: + if self.data is not None: self.save() - elif "filepath" in kwargs: - self.load() + elif self.filepath and not self.data: + try: + self.load() + except FileNotFoundError: + # File doesn't exist yet, that's OK + pass def save(self) -> None: """Save dataset to HDF5 file.""" @@ -188,6 +191,7 @@ def create_datasets( ) -> None: for dataset in datasets: from policyengine_uk import Microsimulation + sim = Microsimulation(dataset=dataset) for year in years: year_dataset = sim.dataset[year] @@ -202,27 +206,40 @@ def create_datasets( household_df[["household_id", "household_weight"]], left_on="person_household_id", right_on="household_id", - how="left" + how="left", + ) + person_df = person_df.rename( + columns={"household_weight": "person_weight"} ) - person_df = person_df.rename(columns={"household_weight": "person_weight"}) person_df = person_df.drop(columns=["household_id"]) # Get household_id for each benunit from person table - benunit_household_map = person_df[["person_benunit_id", 
"person_household_id"]].drop_duplicates() + benunit_household_map = person_df[ + ["person_benunit_id", "person_household_id"] + ].drop_duplicates() benunit_df = benunit_df.merge( benunit_household_map, left_on="benunit_id", right_on="person_benunit_id", - how="left" + how="left", ) benunit_df = benunit_df.merge( household_df[["household_id", "household_weight"]], left_on="person_household_id", right_on="household_id", - how="left" + how="left", + ) + benunit_df = benunit_df.rename( + columns={"household_weight": "benunit_weight"} + ) + benunit_df = benunit_df.drop( + columns=[ + "person_benunit_id", + "person_household_id", + "household_id", + ], + errors="ignore", ) - benunit_df = benunit_df.rename(columns={"household_weight": "benunit_weight"}) - benunit_df = benunit_df.drop(columns=["person_benunit_id", "person_household_id", "household_id"], errors="ignore") uk_dataset = PolicyEngineUKDataset( name=f"{dataset}-year-{year}", @@ -231,8 +248,12 @@ def create_datasets( year=year, data=UKYearData( person=MicroDataFrame(person_df, weights="person_weight"), - benunit=MicroDataFrame(benunit_df, weights="benunit_weight"), - household=MicroDataFrame(household_df, weights="household_weight"), + benunit=MicroDataFrame( + benunit_df, weights="benunit_weight" + ), + household=MicroDataFrame( + household_df, weights="household_weight" + ), ), ) uk_dataset.save() diff --git a/src/policyengine/tax_benefit_models/uk/model.py b/src/policyengine/tax_benefit_models/uk/model.py index d6c0bf72..5a91f5d7 100644 --- a/src/policyengine/tax_benefit_models/uk/model.py +++ b/src/policyengine/tax_benefit_models/uk/model.py @@ -1,4 +1,10 @@ -from policyengine.core import TaxBenefitModel, TaxBenefitModelVersion, Variable, Parameter, ParameterValue +from policyengine.core import ( + TaxBenefitModel, + TaxBenefitModelVersion, + Variable, + Parameter, + ParameterValue, +) import datetime import requests from importlib.metadata import version @@ -8,6 +14,7 @@ from pathlib import Path from .datasets import PolicyEngineUKDataset, UKYearData from typing import TYPE_CHECKING + if TYPE_CHECKING: from policyengine.core.simulation import Simulation @@ -103,7 +110,9 @@ def __init__(self, **kwargs: dict): def run(self, simulation: "Simulation") -> "Simulation": from policyengine_uk import Microsimulation from policyengine_uk.data import UKSingleYearDataset - from policyengine.utils.parametric_reforms import simulation_modifier_from_parameter_values + from policyengine.utils.parametric_reforms import ( + simulation_modifier_from_parameter_values, + ) assert isinstance(simulation.dataset, PolicyEngineUKDataset) diff --git a/src/policyengine/tax_benefit_models/uk/outputs.py b/src/policyengine/tax_benefit_models/uk/outputs.py index 6e3d8742..445a37c5 100644 --- a/src/policyengine/tax_benefit_models/uk/outputs.py +++ b/src/policyengine/tax_benefit_models/uk/outputs.py @@ -2,7 +2,10 @@ from policyengine.core import Output from policyengine.outputs.aggregate import Aggregate, AggregateType -from policyengine.outputs.change_aggregate import ChangeAggregate, ChangeAggregateType +from policyengine.outputs.change_aggregate import ( + ChangeAggregate, + ChangeAggregateType, +) from pydantic import ConfigDict from typing import TYPE_CHECKING diff --git a/src/policyengine/tax_benefit_models/us/__init__.py b/src/policyengine/tax_benefit_models/us/__init__.py index 7e8a6a4c..cbb58218 100644 --- a/src/policyengine/tax_benefit_models/us/__init__.py +++ b/src/policyengine/tax_benefit_models/us/__init__.py @@ -4,7 +4,12 @@ if 
find_spec("policyengine_us") is not None: from .datasets import USYearData, PolicyEngineUSDataset - from .model import PolicyEngineUS, PolicyEngineUSLatest, us_model, us_latest + from .model import ( + PolicyEngineUS, + PolicyEngineUSLatest, + us_model, + us_latest, + ) from .analysis import general_policy_reform_analysis from .outputs import ProgramStatistics diff --git a/src/policyengine/tax_benefit_models/us/analysis.py b/src/policyengine/tax_benefit_models/us/analysis.py index 9a0d01f0..905749f1 100644 --- a/src/policyengine/tax_benefit_models/us/analysis.py +++ b/src/policyengine/tax_benefit_models/us/analysis.py @@ -1,7 +1,10 @@ """General utility functions for US policy reform analysis.""" from policyengine.core import Simulation, OutputCollection -from policyengine.outputs.decile_impact import DecileImpact, calculate_decile_impacts +from policyengine.outputs.decile_impact import ( + DecileImpact, + calculate_decile_impacts, +) from .outputs import ProgramStatistics from pydantic import BaseModel import pandas as pd @@ -65,30 +68,30 @@ def general_policy_reform_analysis( program_statistics.append(stats) # Create DataFrame - program_df = pd.DataFrame([ - { - "baseline_simulation_id": p.baseline_simulation.id, - "reform_simulation_id": p.reform_simulation.id, - "program_name": p.program_name, - "entity": p.entity, - "is_tax": p.is_tax, - "baseline_total": p.baseline_total, - "reform_total": p.reform_total, - "change": p.change, - "baseline_count": p.baseline_count, - "reform_count": p.reform_count, - "winners": p.winners, - "losers": p.losers, - } - for p in program_statistics - ]) + program_df = pd.DataFrame( + [ + { + "baseline_simulation_id": p.baseline_simulation.id, + "reform_simulation_id": p.reform_simulation.id, + "program_name": p.program_name, + "entity": p.entity, + "is_tax": p.is_tax, + "baseline_total": p.baseline_total, + "reform_total": p.reform_total, + "change": p.change, + "baseline_count": p.baseline_count, + "reform_count": p.reform_count, + "winners": p.winners, + "losers": p.losers, + } + for p in program_statistics + ] + ) program_collection = OutputCollection( - outputs=program_statistics, - dataframe=program_df + outputs=program_statistics, dataframe=program_df ) return PolicyReformAnalysis( - decile_impacts=decile_impacts, - program_statistics=program_collection + decile_impacts=decile_impacts, program_statistics=program_collection ) diff --git a/src/policyengine/tax_benefit_models/us/datasets.py b/src/policyengine/tax_benefit_models/us/datasets.py index b9bad8c5..8545cfc7 100644 --- a/src/policyengine/tax_benefit_models/us/datasets.py +++ b/src/policyengine/tax_benefit_models/us/datasets.py @@ -33,7 +33,14 @@ def map_to_entity( Raises: ValueError: If source or target entity is invalid. """ - valid_entities = {"person", "marital_unit", "family", "spm_unit", "tax_unit", "household"} + valid_entities = { + "person", + "marital_unit", + "family", + "spm_unit", + "tax_unit", + "household", + } if source_entity not in valid_entities: raise ValueError( f"Invalid source entity '{source_entity}'. 
Must be one of {valid_entities}" @@ -47,7 +54,14 @@ def map_to_entity( source_df = getattr(self, source_entity) if columns: # Select only requested columns (keep join keys) - join_keys = {"person_id", "marital_unit_id", "family_id", "spm_unit_id", "tax_unit_id", "household_id"} + join_keys = { + "person_id", + "marital_unit_id", + "family_id", + "spm_unit_id", + "tax_unit_id", + "household_id", + } cols_to_keep = list( set(columns) | (join_keys & set(source_df.columns)) ) @@ -88,8 +102,13 @@ def map_to_entity( source_key = f"{source_entity}_id" # Link source -> person -> target - if source_key in person_df.columns and target_key in person_df.columns: - person_link = person_df[[source_key, target_key]].drop_duplicates() + if ( + source_key in person_df.columns + and target_key in person_df.columns + ): + person_link = person_df[ + [source_key, target_key] + ].drop_duplicates() source_with_target = pd.DataFrame(source_df).merge( person_link, on=source_key, how="left" ) diff --git a/src/policyengine/tax_benefit_models/us/model.py b/src/policyengine/tax_benefit_models/us/model.py index 676ade1b..9c4196f7 100644 --- a/src/policyengine/tax_benefit_models/us/model.py +++ b/src/policyengine/tax_benefit_models/us/model.py @@ -1,4 +1,10 @@ -from policyengine.core import TaxBenefitModel, TaxBenefitModelVersion, Variable, Parameter, ParameterValue +from policyengine.core import ( + TaxBenefitModel, + TaxBenefitModelVersion, + Variable, + Parameter, + ParameterValue, +) import datetime import requests from importlib.metadata import version @@ -76,9 +82,7 @@ def __init__(self, **kwargs: dict): name=param_node.name, tax_benefit_model_version=self, description=param_node.description, - data_type=type( - param_node(2025) - ), + data_type=type(param_node(2025)), unit=param_node.metadata.get("unit"), ) self.parameters.append(parameter) @@ -103,7 +107,9 @@ def __init__(self, **kwargs: dict): def run(self, simulation: "Simulation") -> "Simulation": from policyengine_us import Microsimulation - from policyengine.utils.parametric_reforms import simulation_modifier_from_parameter_values + from policyengine.utils.parametric_reforms import ( + simulation_modifier_from_parameter_values, + ) assert isinstance(simulation.dataset, PolicyEngineUSDataset) diff --git a/src/policyengine/tax_benefit_models/us/outputs.py b/src/policyengine/tax_benefit_models/us/outputs.py index dd1131cf..fb54ed5f 100644 --- a/src/policyengine/tax_benefit_models/us/outputs.py +++ b/src/policyengine/tax_benefit_models/us/outputs.py @@ -2,7 +2,10 @@ from policyengine.core import Output from policyengine.outputs.aggregate import Aggregate, AggregateType -from policyengine.outputs.change_aggregate import ChangeAggregate, ChangeAggregateType +from policyengine.outputs.change_aggregate import ( + ChangeAggregate, + ChangeAggregateType, +) from pydantic import ConfigDict from typing import TYPE_CHECKING diff --git a/src/policyengine/utils/parametric_reforms.py b/src/policyengine/utils/parametric_reforms.py index 81c9fba3..88918ec8 100644 --- a/src/policyengine/utils/parametric_reforms.py +++ b/src/policyengine/utils/parametric_reforms.py @@ -3,7 +3,9 @@ from policyengine_core.periods import period -def simulation_modifier_from_parameter_values(parameter_values: list[ParameterValue]) -> Callable: +def simulation_modifier_from_parameter_values( + parameter_values: list[ParameterValue], +) -> Callable: """ Create a simulation modifier function that applies the given parameter values to a simulation. 
@@ -16,9 +18,15 @@ def simulation_modifier_from_parameter_values(parameter_values: list[ParameterVa def modifier(simulation): for pv in parameter_values: - p = simulation.tax_benefit_system.parameters.get_child(pv.parameter.name) + p = simulation.tax_benefit_system.parameters.get_child( + pv.parameter.name + ) start_period = period(pv.start_date.strftime("%Y-%m-%d")) - stop_period = period(pv.end_date.strftime("%Y-%m-%d")) if pv.end_date else None + stop_period = ( + period(pv.end_date.strftime("%Y-%m-%d")) + if pv.end_date + else None + ) p.update( value=pv.value, start=start_period, @@ -26,4 +34,4 @@ def modifier(simulation): ) return simulation - return modifier \ No newline at end of file + return modifier diff --git a/tests/test_change_aggregate.py b/tests/test_change_aggregate.py index de878996..6006cbe1 100644 --- a/tests/test_change_aggregate.py +++ b/tests/test_change_aggregate.py @@ -8,37 +8,46 @@ UKYearData, uk_latest, ) -from policyengine.outputs.change_aggregate import ChangeAggregate, ChangeAggregateType +from policyengine.outputs.change_aggregate import ( + ChangeAggregate, + ChangeAggregateType, +) def test_change_aggregate_count(): """Test counting people with any change.""" person_df = MicroDataFrame( - pd.DataFrame({ - "person_id": [1, 2, 3, 4], - "benunit_id": [1, 1, 2, 2], - "household_id": [1, 1, 2, 2], - "age": [30, 25, 40, 35], - "employment_income": [50000, 30000, 60000, 40000], - "person_weight": [1.0, 1.0, 1.0, 1.0], - }), - weights="person_weight" + pd.DataFrame( + { + "person_id": [1, 2, 3, 4], + "benunit_id": [1, 1, 2, 2], + "household_id": [1, 1, 2, 2], + "age": [30, 25, 40, 35], + "employment_income": [50000, 30000, 60000, 40000], + "person_weight": [1.0, 1.0, 1.0, 1.0], + } + ), + weights="person_weight", ) benunit_df = MicroDataFrame( - pd.DataFrame({ - "benunit_id": [1, 2], - "benunit_weight": [1.0, 1.0], - }), - weights="benunit_weight" + pd.DataFrame( + { + "benunit_id": [1, 2], + "benunit_weight": [1.0, 1.0], + } + ), + weights="benunit_weight", ) household_df = MicroDataFrame( - pd.DataFrame({ - "household_id": [1, 2], - "household_weight": [1.0, 1.0], - }), - weights="household_weight" + pd.DataFrame( + { + "household_id": [1, 2], + "household_weight": [1.0, 1.0], + } + ), + weights="household_weight", ) with tempfile.TemporaryDirectory() as tmpdir: @@ -50,20 +59,24 @@ def test_change_aggregate_count(): description="Baseline dataset", filepath=baseline_filepath, year=2024, - data=UKYearData(person=person_df, benunit=benunit_df, household=household_df), + data=UKYearData( + person=person_df, benunit=benunit_df, household=household_df + ), ) # Reform: increase everyone's income by 1000 reform_person_df = MicroDataFrame( - pd.DataFrame({ - "person_id": [1, 2, 3, 4], - "benunit_id": [1, 1, 2, 2], - "household_id": [1, 1, 2, 2], - "age": [30, 25, 40, 35], - "employment_income": [51000, 31000, 61000, 41000], - "person_weight": [1.0, 1.0, 1.0, 1.0], - }), - weights="person_weight" + pd.DataFrame( + { + "person_id": [1, 2, 3, 4], + "benunit_id": [1, 1, 2, 2], + "household_id": [1, 1, 2, 2], + "age": [30, 25, 40, 35], + "employment_income": [51000, 31000, 61000, 41000], + "person_weight": [1.0, 1.0, 1.0, 1.0], + } + ), + weights="person_weight", ) reform_dataset = PolicyEngineUKDataset( @@ -71,7 +84,11 @@ def test_change_aggregate_count(): description="Reform dataset", filepath=reform_filepath, year=2024, - data=UKYearData(person=reform_person_df, benunit=benunit_df, household=household_df), + data=UKYearData( + person=reform_person_df, + benunit=benunit_df, 
+ household=household_df, + ), ) baseline_sim = Simulation( @@ -101,31 +118,37 @@ def test_change_aggregate_count(): def test_change_aggregate_with_absolute_filter(): """Test filtering by absolute change amount.""" person_df = MicroDataFrame( - pd.DataFrame({ - "person_id": [1, 2, 3, 4], - "benunit_id": [1, 1, 2, 2], - "household_id": [1, 1, 2, 2], - "age": [30, 25, 40, 35], - "employment_income": [50000, 30000, 60000, 40000], - "person_weight": [1.0, 1.0, 1.0, 1.0], - }), - weights="person_weight" + pd.DataFrame( + { + "person_id": [1, 2, 3, 4], + "benunit_id": [1, 1, 2, 2], + "household_id": [1, 1, 2, 2], + "age": [30, 25, 40, 35], + "employment_income": [50000, 30000, 60000, 40000], + "person_weight": [1.0, 1.0, 1.0, 1.0], + } + ), + weights="person_weight", ) benunit_df = MicroDataFrame( - pd.DataFrame({ - "benunit_id": [1, 2], - "benunit_weight": [1.0, 1.0], - }), - weights="benunit_weight" + pd.DataFrame( + { + "benunit_id": [1, 2], + "benunit_weight": [1.0, 1.0], + } + ), + weights="benunit_weight", ) household_df = MicroDataFrame( - pd.DataFrame({ - "household_id": [1, 2], - "household_weight": [1.0, 1.0], - }), - weights="household_weight" + pd.DataFrame( + { + "household_id": [1, 2], + "household_weight": [1.0, 1.0], + } + ), + weights="household_weight", ) with tempfile.TemporaryDirectory() as tmpdir: @@ -137,20 +160,29 @@ def test_change_aggregate_with_absolute_filter(): description="Baseline dataset", filepath=baseline_filepath, year=2024, - data=UKYearData(person=person_df, benunit=benunit_df, household=household_df), + data=UKYearData( + person=person_df, benunit=benunit_df, household=household_df + ), ) # Reform: different gains for different people reform_person_df = MicroDataFrame( - pd.DataFrame({ - "person_id": [1, 2, 3, 4], - "benunit_id": [1, 1, 2, 2], - "household_id": [1, 1, 2, 2], - "age": [30, 25, 40, 35], - "employment_income": [52000, 30500, 61500, 40200], # Gains: 2000, 500, 1500, 200 - "person_weight": [1.0, 1.0, 1.0, 1.0], - }), - weights="person_weight" + pd.DataFrame( + { + "person_id": [1, 2, 3, 4], + "benunit_id": [1, 1, 2, 2], + "household_id": [1, 1, 2, 2], + "age": [30, 25, 40, 35], + "employment_income": [ + 52000, + 30500, + 61500, + 40200, + ], # Gains: 2000, 500, 1500, 200 + "person_weight": [1.0, 1.0, 1.0, 1.0], + } + ), + weights="person_weight", ) reform_dataset = PolicyEngineUKDataset( @@ -158,7 +190,11 @@ def test_change_aggregate_with_absolute_filter(): description="Reform dataset", filepath=reform_filepath, year=2024, - data=UKYearData(person=reform_person_df, benunit=benunit_df, household=household_df), + data=UKYearData( + person=reform_person_df, + benunit=benunit_df, + household=household_df, + ), ) baseline_sim = Simulation( @@ -189,31 +225,37 @@ def test_change_aggregate_with_absolute_filter(): def test_change_aggregate_with_loss_filter(): """Test filtering for losses (negative changes).""" person_df = MicroDataFrame( - pd.DataFrame({ - "person_id": [1, 2, 3, 4], - "benunit_id": [1, 1, 2, 2], - "household_id": [1, 1, 2, 2], - "age": [30, 25, 40, 35], - "employment_income": [50000, 30000, 60000, 40000], - "person_weight": [1.0, 1.0, 1.0, 1.0], - }), - weights="person_weight" + pd.DataFrame( + { + "person_id": [1, 2, 3, 4], + "benunit_id": [1, 1, 2, 2], + "household_id": [1, 1, 2, 2], + "age": [30, 25, 40, 35], + "employment_income": [50000, 30000, 60000, 40000], + "person_weight": [1.0, 1.0, 1.0, 1.0], + } + ), + weights="person_weight", ) benunit_df = MicroDataFrame( - pd.DataFrame({ - "benunit_id": [1, 2], - "benunit_weight": [1.0, 
1.0], - }), - weights="benunit_weight" + pd.DataFrame( + { + "benunit_id": [1, 2], + "benunit_weight": [1.0, 1.0], + } + ), + weights="benunit_weight", ) household_df = MicroDataFrame( - pd.DataFrame({ - "household_id": [1, 2], - "household_weight": [1.0, 1.0], - }), - weights="household_weight" + pd.DataFrame( + { + "household_id": [1, 2], + "household_weight": [1.0, 1.0], + } + ), + weights="household_weight", ) with tempfile.TemporaryDirectory() as tmpdir: @@ -225,20 +267,29 @@ def test_change_aggregate_with_loss_filter(): description="Baseline dataset", filepath=baseline_filepath, year=2024, - data=UKYearData(person=person_df, benunit=benunit_df, household=household_df), + data=UKYearData( + person=person_df, benunit=benunit_df, household=household_df + ), ) # Reform: some people lose money reform_person_df = MicroDataFrame( - pd.DataFrame({ - "person_id": [1, 2, 3, 4], - "benunit_id": [1, 1, 2, 2], - "household_id": [1, 1, 2, 2], - "age": [30, 25, 40, 35], - "employment_income": [49000, 29000, 60500, 39000], # Changes: -1000, -1000, 500, -1000 - "person_weight": [1.0, 1.0, 1.0, 1.0], - }), - weights="person_weight" + pd.DataFrame( + { + "person_id": [1, 2, 3, 4], + "benunit_id": [1, 1, 2, 2], + "household_id": [1, 1, 2, 2], + "age": [30, 25, 40, 35], + "employment_income": [ + 49000, + 29000, + 60500, + 39000, + ], # Changes: -1000, -1000, 500, -1000 + "person_weight": [1.0, 1.0, 1.0, 1.0], + } + ), + weights="person_weight", ) reform_dataset = PolicyEngineUKDataset( @@ -246,7 +297,11 @@ def test_change_aggregate_with_loss_filter(): description="Reform dataset", filepath=reform_filepath, year=2024, - data=UKYearData(person=reform_person_df, benunit=benunit_df, household=household_df), + data=UKYearData( + person=reform_person_df, + benunit=benunit_df, + household=household_df, + ), ) baseline_sim = Simulation( @@ -277,31 +332,37 @@ def test_change_aggregate_with_loss_filter(): def test_change_aggregate_with_relative_filter(): """Test filtering by relative (percentage) change.""" person_df = MicroDataFrame( - pd.DataFrame({ - "person_id": [1, 2, 3, 4], - "benunit_id": [1, 1, 2, 2], - "household_id": [1, 1, 2, 2], - "age": [30, 25, 40, 35], - "employment_income": [50000, 20000, 60000, 40000], - "person_weight": [1.0, 1.0, 1.0, 1.0], - }), - weights="person_weight" + pd.DataFrame( + { + "person_id": [1, 2, 3, 4], + "benunit_id": [1, 1, 2, 2], + "household_id": [1, 1, 2, 2], + "age": [30, 25, 40, 35], + "employment_income": [50000, 20000, 60000, 40000], + "person_weight": [1.0, 1.0, 1.0, 1.0], + } + ), + weights="person_weight", ) benunit_df = MicroDataFrame( - pd.DataFrame({ - "benunit_id": [1, 2], - "benunit_weight": [1.0, 1.0], - }), - weights="benunit_weight" + pd.DataFrame( + { + "benunit_id": [1, 2], + "benunit_weight": [1.0, 1.0], + } + ), + weights="benunit_weight", ) household_df = MicroDataFrame( - pd.DataFrame({ - "household_id": [1, 2], - "household_weight": [1.0, 1.0], - }), - weights="household_weight" + pd.DataFrame( + { + "household_id": [1, 2], + "household_weight": [1.0, 1.0], + } + ), + weights="household_weight", ) with tempfile.TemporaryDirectory() as tmpdir: @@ -313,21 +374,25 @@ def test_change_aggregate_with_relative_filter(): description="Baseline dataset", filepath=baseline_filepath, year=2024, - data=UKYearData(person=person_df, benunit=benunit_df, household=household_df), + data=UKYearData( + person=person_df, benunit=benunit_df, household=household_df + ), ) # Reform: different percentage gains reform_person_df = MicroDataFrame( - pd.DataFrame({ - "person_id": 
[1, 2, 3, 4], - "benunit_id": [1, 1, 2, 2], - "household_id": [1, 1, 2, 2], - "age": [30, 25, 40, 35], - # Gains: 5000 (10%), 2000 (10%), 3000 (5%), 1000 (2.5%) - "employment_income": [55000, 22000, 63000, 41000], - "person_weight": [1.0, 1.0, 1.0, 1.0], - }), - weights="person_weight" + pd.DataFrame( + { + "person_id": [1, 2, 3, 4], + "benunit_id": [1, 1, 2, 2], + "household_id": [1, 1, 2, 2], + "age": [30, 25, 40, 35], + # Gains: 5000 (10%), 2000 (10%), 3000 (5%), 1000 (2.5%) + "employment_income": [55000, 22000, 63000, 41000], + "person_weight": [1.0, 1.0, 1.0, 1.0], + } + ), + weights="person_weight", ) reform_dataset = PolicyEngineUKDataset( @@ -335,7 +400,11 @@ def test_change_aggregate_with_relative_filter(): description="Reform dataset", filepath=reform_filepath, year=2024, - data=UKYearData(person=reform_person_df, benunit=benunit_df, household=household_df), + data=UKYearData( + person=reform_person_df, + benunit=benunit_df, + household=household_df, + ), ) baseline_sim = Simulation( @@ -366,31 +435,37 @@ def test_change_aggregate_with_relative_filter(): def test_change_aggregate_sum(): """Test summing changes.""" person_df = MicroDataFrame( - pd.DataFrame({ - "person_id": [1, 2, 3], - "benunit_id": [1, 1, 2], - "household_id": [1, 1, 2], - "age": [30, 25, 40], - "employment_income": [50000, 30000, 60000], - "person_weight": [1.0, 1.0, 1.0], - }), - weights="person_weight" + pd.DataFrame( + { + "person_id": [1, 2, 3], + "benunit_id": [1, 1, 2], + "household_id": [1, 1, 2], + "age": [30, 25, 40], + "employment_income": [50000, 30000, 60000], + "person_weight": [1.0, 1.0, 1.0], + } + ), + weights="person_weight", ) benunit_df = MicroDataFrame( - pd.DataFrame({ - "benunit_id": [1, 2], - "benunit_weight": [1.0, 1.0], - }), - weights="benunit_weight" + pd.DataFrame( + { + "benunit_id": [1, 2], + "benunit_weight": [1.0, 1.0], + } + ), + weights="benunit_weight", ) household_df = MicroDataFrame( - pd.DataFrame({ - "household_id": [1, 2], - "household_weight": [1.0, 1.0], - }), - weights="household_weight" + pd.DataFrame( + { + "household_id": [1, 2], + "household_weight": [1.0, 1.0], + } + ), + weights="household_weight", ) with tempfile.TemporaryDirectory() as tmpdir: @@ -402,20 +477,24 @@ def test_change_aggregate_sum(): description="Baseline dataset", filepath=baseline_filepath, year=2024, - data=UKYearData(person=person_df, benunit=benunit_df, household=household_df), + data=UKYearData( + person=person_df, benunit=benunit_df, household=household_df + ), ) # Reform: everyone gains 1000 reform_person_df = MicroDataFrame( - pd.DataFrame({ - "person_id": [1, 2, 3], - "benunit_id": [1, 1, 2], - "household_id": [1, 1, 2], - "age": [30, 25, 40], - "employment_income": [51000, 31000, 61000], - "person_weight": [1.0, 1.0, 1.0], - }), - weights="person_weight" + pd.DataFrame( + { + "person_id": [1, 2, 3], + "benunit_id": [1, 1, 2], + "household_id": [1, 1, 2], + "age": [30, 25, 40], + "employment_income": [51000, 31000, 61000], + "person_weight": [1.0, 1.0, 1.0], + } + ), + weights="person_weight", ) reform_dataset = PolicyEngineUKDataset( @@ -423,7 +502,11 @@ def test_change_aggregate_sum(): description="Reform dataset", filepath=reform_filepath, year=2024, - data=UKYearData(person=reform_person_df, benunit=benunit_df, household=household_df), + data=UKYearData( + person=reform_person_df, + benunit=benunit_df, + household=household_df, + ), ) baseline_sim = Simulation( @@ -453,31 +536,37 @@ def test_change_aggregate_sum(): def test_change_aggregate_mean(): """Test mean change.""" person_df = 
MicroDataFrame( - pd.DataFrame({ - "person_id": [1, 2, 3], - "benunit_id": [1, 1, 2], - "household_id": [1, 1, 2], - "age": [30, 25, 40], - "employment_income": [50000, 30000, 60000], - "person_weight": [1.0, 1.0, 1.0], - }), - weights="person_weight" + pd.DataFrame( + { + "person_id": [1, 2, 3], + "benunit_id": [1, 1, 2], + "household_id": [1, 1, 2], + "age": [30, 25, 40], + "employment_income": [50000, 30000, 60000], + "person_weight": [1.0, 1.0, 1.0], + } + ), + weights="person_weight", ) benunit_df = MicroDataFrame( - pd.DataFrame({ - "benunit_id": [1, 2], - "benunit_weight": [1.0, 1.0], - }), - weights="benunit_weight" + pd.DataFrame( + { + "benunit_id": [1, 2], + "benunit_weight": [1.0, 1.0], + } + ), + weights="benunit_weight", ) household_df = MicroDataFrame( - pd.DataFrame({ - "household_id": [1, 2], - "household_weight": [1.0, 1.0], - }), - weights="household_weight" + pd.DataFrame( + { + "household_id": [1, 2], + "household_weight": [1.0, 1.0], + } + ), + weights="household_weight", ) with tempfile.TemporaryDirectory() as tmpdir: @@ -489,20 +578,28 @@ def test_change_aggregate_mean(): description="Baseline dataset", filepath=baseline_filepath, year=2024, - data=UKYearData(person=person_df, benunit=benunit_df, household=household_df), + data=UKYearData( + person=person_df, benunit=benunit_df, household=household_df + ), ) # Reform: different gains reform_person_df = MicroDataFrame( - pd.DataFrame({ - "person_id": [1, 2, 3], - "benunit_id": [1, 1, 2], - "household_id": [1, 1, 2], - "age": [30, 25, 40], - "employment_income": [51000, 32000, 63000], # Gains: 1000, 2000, 3000 - "person_weight": [1.0, 1.0, 1.0], - }), - weights="person_weight" + pd.DataFrame( + { + "person_id": [1, 2, 3], + "benunit_id": [1, 1, 2], + "household_id": [1, 1, 2], + "age": [30, 25, 40], + "employment_income": [ + 51000, + 32000, + 63000, + ], # Gains: 1000, 2000, 3000 + "person_weight": [1.0, 1.0, 1.0], + } + ), + weights="person_weight", ) reform_dataset = PolicyEngineUKDataset( @@ -510,7 +607,11 @@ def test_change_aggregate_mean(): description="Reform dataset", filepath=reform_filepath, year=2024, - data=UKYearData(person=reform_person_df, benunit=benunit_df, household=household_df), + data=UKYearData( + person=reform_person_df, + benunit=benunit_df, + household=household_df, + ), ) baseline_sim = Simulation( @@ -540,31 +641,37 @@ def test_change_aggregate_mean(): def test_change_aggregate_with_filter_variable(): """Test filtering by another variable (e.g., only adults).""" person_df = MicroDataFrame( - pd.DataFrame({ - "person_id": [1, 2, 3, 4], - "benunit_id": [1, 1, 2, 2], - "household_id": [1, 1, 2, 2], - "age": [30, 25, 40, 15], # Person 4 is a child - "employment_income": [50000, 30000, 60000, 5000], - "person_weight": [1.0, 1.0, 1.0, 1.0], - }), - weights="person_weight" + pd.DataFrame( + { + "person_id": [1, 2, 3, 4], + "benunit_id": [1, 1, 2, 2], + "household_id": [1, 1, 2, 2], + "age": [30, 25, 40, 15], # Person 4 is a child + "employment_income": [50000, 30000, 60000, 5000], + "person_weight": [1.0, 1.0, 1.0, 1.0], + } + ), + weights="person_weight", ) benunit_df = MicroDataFrame( - pd.DataFrame({ - "benunit_id": [1, 2], - "benunit_weight": [1.0, 1.0], - }), - weights="benunit_weight" + pd.DataFrame( + { + "benunit_id": [1, 2], + "benunit_weight": [1.0, 1.0], + } + ), + weights="benunit_weight", ) household_df = MicroDataFrame( - pd.DataFrame({ - "household_id": [1, 2], - "household_weight": [1.0, 1.0], - }), - weights="household_weight" + pd.DataFrame( + { + "household_id": [1, 2], + 
"household_weight": [1.0, 1.0], + } + ), + weights="household_weight", ) with tempfile.TemporaryDirectory() as tmpdir: @@ -576,20 +683,24 @@ def test_change_aggregate_with_filter_variable(): description="Baseline dataset", filepath=baseline_filepath, year=2024, - data=UKYearData(person=person_df, benunit=benunit_df, household=household_df), + data=UKYearData( + person=person_df, benunit=benunit_df, household=household_df + ), ) # Reform: everyone gains 1000 reform_person_df = MicroDataFrame( - pd.DataFrame({ - "person_id": [1, 2, 3, 4], - "benunit_id": [1, 1, 2, 2], - "household_id": [1, 1, 2, 2], - "age": [30, 25, 40, 15], - "employment_income": [51000, 31000, 61000, 6000], - "person_weight": [1.0, 1.0, 1.0, 1.0], - }), - weights="person_weight" + pd.DataFrame( + { + "person_id": [1, 2, 3, 4], + "benunit_id": [1, 1, 2, 2], + "household_id": [1, 1, 2, 2], + "age": [30, 25, 40, 15], + "employment_income": [51000, 31000, 61000, 6000], + "person_weight": [1.0, 1.0, 1.0, 1.0], + } + ), + weights="person_weight", ) reform_dataset = PolicyEngineUKDataset( @@ -597,7 +708,11 @@ def test_change_aggregate_with_filter_variable(): description="Reform dataset", filepath=reform_filepath, year=2024, - data=UKYearData(person=reform_person_df, benunit=benunit_df, household=household_df), + data=UKYearData( + person=reform_person_df, + benunit=benunit_df, + household=household_df, + ), ) baseline_sim = Simulation( @@ -630,31 +745,37 @@ def test_change_aggregate_with_filter_variable(): def test_change_aggregate_combined_filters(): """Test combining multiple filter types.""" person_df = MicroDataFrame( - pd.DataFrame({ - "person_id": [1, 2, 3, 4, 5], - "benunit_id": [1, 1, 2, 2, 3], - "household_id": [1, 1, 2, 2, 3], - "age": [30, 25, 40, 35, 45], - "employment_income": [50000, 20000, 60000, 40000, 80000], - "person_weight": [1.0, 1.0, 1.0, 1.0, 1.0], - }), - weights="person_weight" + pd.DataFrame( + { + "person_id": [1, 2, 3, 4, 5], + "benunit_id": [1, 1, 2, 2, 3], + "household_id": [1, 1, 2, 2, 3], + "age": [30, 25, 40, 35, 45], + "employment_income": [50000, 20000, 60000, 40000, 80000], + "person_weight": [1.0, 1.0, 1.0, 1.0, 1.0], + } + ), + weights="person_weight", ) benunit_df = MicroDataFrame( - pd.DataFrame({ - "benunit_id": [1, 2, 3], - "benunit_weight": [1.0, 1.0, 1.0], - }), - weights="benunit_weight" + pd.DataFrame( + { + "benunit_id": [1, 2, 3], + "benunit_weight": [1.0, 1.0, 1.0], + } + ), + weights="benunit_weight", ) household_df = MicroDataFrame( - pd.DataFrame({ - "household_id": [1, 2, 3], - "household_weight": [1.0, 1.0, 1.0], - }), - weights="household_weight" + pd.DataFrame( + { + "household_id": [1, 2, 3], + "household_weight": [1.0, 1.0, 1.0], + } + ), + weights="household_weight", ) with tempfile.TemporaryDirectory() as tmpdir: @@ -666,21 +787,25 @@ def test_change_aggregate_combined_filters(): description="Baseline dataset", filepath=baseline_filepath, year=2024, - data=UKYearData(person=person_df, benunit=benunit_df, household=household_df), + data=UKYearData( + person=person_df, benunit=benunit_df, household=household_df + ), ) # Reform: varying gains reform_person_df = MicroDataFrame( - pd.DataFrame({ - "person_id": [1, 2, 3, 4, 5], - "benunit_id": [1, 1, 2, 2, 3], - "household_id": [1, 1, 2, 2, 3], - "age": [30, 25, 40, 35, 45], - # Changes: 10000 (20%), 2000 (10%), 3000 (5%), 800 (2%), 4000 (5%) - "employment_income": [60000, 22000, 63000, 40800, 84000], - "person_weight": [1.0, 1.0, 1.0, 1.0, 1.0], - }), - weights="person_weight" + pd.DataFrame( + { + "person_id": [1, 2, 3, 4, 
5], + "benunit_id": [1, 1, 2, 2, 3], + "household_id": [1, 1, 2, 2, 3], + "age": [30, 25, 40, 35, 45], + # Changes: 10000 (20%), 2000 (10%), 3000 (5%), 800 (2%), 4000 (5%) + "employment_income": [60000, 22000, 63000, 40800, 84000], + "person_weight": [1.0, 1.0, 1.0, 1.0, 1.0], + } + ), + weights="person_weight", ) reform_dataset = PolicyEngineUKDataset( @@ -688,7 +813,11 @@ def test_change_aggregate_combined_filters(): description="Reform dataset", filepath=reform_filepath, year=2024, - data=UKYearData(person=reform_person_df, benunit=benunit_df, household=household_df), + data=UKYearData( + person=reform_person_df, + benunit=benunit_df, + household=household_df, + ), ) baseline_sim = Simulation( From bc61df730188988098ce00801aafee44472d7e4c Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Sun, 16 Nov 2025 20:27:11 +0000 Subject: [PATCH 25/35] US works! --- docs/visualisation.md | 72 ++++ examples/employment_income_variation.py | 279 ------------- examples/employment_income_variation_uk.py | 383 ++++++++++++++++++ examples/employment_income_variation_us.py | 356 ++++++++++++++++ .../tax_benefit_models/us/__init__.py | 7 + .../tax_benefit_models/us/datasets.py | 42 +- .../tax_benefit_models/us/model.py | 138 ++++++- src/policyengine/utils/__init__.py | 1 + src/policyengine/utils/dates.py | 18 +- src/policyengine/utils/plotting.py | 178 ++++++++ tests/test_us_simulation.py | 260 ++++++++++++ 11 files changed, 1428 insertions(+), 306 deletions(-) create mode 100644 docs/visualisation.md delete mode 100644 examples/employment_income_variation.py create mode 100644 examples/employment_income_variation_uk.py create mode 100644 examples/employment_income_variation_us.py create mode 100644 src/policyengine/utils/plotting.py create mode 100644 tests/test_us_simulation.py diff --git a/docs/visualisation.md b/docs/visualisation.md new file mode 100644 index 00000000..639f12ae --- /dev/null +++ b/docs/visualisation.md @@ -0,0 +1,72 @@ +# Visualisation utilities + +PolicyEngine provides utilities for creating publication-ready charts that follow our visual style guidelines. + +## Formatting plotly figures + +The `format_fig()` function applies PolicyEngine's visual style to plotly figures, ensuring consistency across all analyses and publications. + +```python +from policyengine.utils import format_fig, COLORS +import plotly.graph_objects as go + +# Create your figure +fig = go.Figure() +fig.add_trace(go.Scatter(x=[1, 2, 3], y=[4, 5, 6], name="Data")) + +# Apply PolicyEngine styling +format_fig( + fig, + title="Example chart", + xaxis_title="X axis", + yaxis_title="Y axis", + height=600, + width=800 +) + +fig.show() +``` + +## Visual style principles + +The formatting applies these principles automatically: + +**Colours**: Primary teal (#319795) with semantic colours for different data types (success/green, warning/yellow, error/red, info/blue). Access colours via the `COLORS` dictionary: + +```python +from policyengine.utils import COLORS + +fig.add_trace(go.Scatter( + x=x_data, + y=y_data, + line=dict(color=COLORS["primary"]) +)) +``` + +**Typography**: Inter font family with appropriate sizing (12px for labels, 14px for body text, 16px for titles). + +**Layout**: Clean white background with subtle grey gridlines and appropriate margins (48px) for professional presentation. + +**Clarity**: Data-driven design that prioritises immediate understanding over decoration. 
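+For example, here is a minimal sketch of putting the semantic palette to work: colouring gains green and losses red in a decile chart. The per-decile change figures below are illustrative placeholders, not simulation output:
+
+```python
+import plotly.graph_objects as go
+
+from policyengine.utils import format_fig, COLORS
+
+deciles = list(range(1, 11))
+changes = [120, 80, 45, 10, -5, -30, -60, -90, -140, -200]  # Illustrative only
+
+fig = go.Figure(
+    go.Bar(
+        x=deciles,
+        y=changes,
+        # Semantic colours: green for gains, red for losses
+        marker_color=[
+            COLORS["success"] if c >= 0 else COLORS["error"] for c in changes
+        ],
+    )
+)
+
+format_fig(
+    fig,
+    title="Change in household net income by decile",
+    xaxis_title="Income decile",
+    yaxis_title="Change (£)",
+)
+fig.show()
+```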
+
+## Available colours
+
+```python
+COLORS = {
+    "primary": "#319795",  # Teal (main brand colour)
+    "primary_light": "#E6FFFA",
+    "primary_dark": "#1D4044",
+    "success": "#22C55E",  # Green (positive changes)
+    "warning": "#FEC601",  # Yellow (cautions)
+    "error": "#EF4444",  # Red (negative changes)
+    "info": "#1890FF",  # Blue (neutral information)
+    "gray_light": "#F2F4F7",
+    "gray": "#667085",
+    "gray_dark": "#101828",
+    "blue_secondary": "#026AA2",
+}
+```
+
+## Complete example
+
+See `examples/employment_income_variation_uk.py` and `examples/employment_income_variation_us.py` for full demonstrations of using `format_fig()` in an analysis workflow.
diff --git a/examples/employment_income_variation.py b/examples/employment_income_variation.py
deleted file mode 100644
index 7fa02b7a..00000000
--- a/examples/employment_income_variation.py
+++ /dev/null
@@ -1,279 +0,0 @@
-"""Example: Vary employment income and plot HBAI household net income.
-
-This script demonstrates:
-1. Creating a custom dataset with a single household template
-2. Varying employment income from £0 to £100k
-3. Running a single simulation for all variations
-4. Using Aggregate with filters to extract results by employment income
-5. Visualising the relationship between employment income and net income
-
-Run: python examples/employment_income_variation.py
-"""
-
-import pandas as pd
-import tempfile
-from pathlib import Path
-import plotly.graph_objects as go
-from microdf import MicroDataFrame
-from policyengine.core import Simulation
-from policyengine.tax_benefit_models.uk import (
-    PolicyEngineUKDataset,
-    UKYearData,
-    uk_latest,
-)
-from policyengine.outputs.aggregate import Aggregate, AggregateType
-
-
-def create_dataset_with_varied_employment_income(
-    employment_incomes: list[float], year: int = 2026
-) -> PolicyEngineUKDataset:
-    """Create a dataset with one household template, varied by employment income.
-
-    Each household is a single adult with varying employment income.
-    Everything else about the household is kept constant.
- """ - n_households = len(employment_incomes) - - # Create person data - one adult per household with varying employment income - person_data = { - "person_id": list(range(n_households)), - "person_benunit_id": list(range(n_households)), # Link to benunit - "person_household_id": list(range(n_households)), # Link to household - "age": [35] * n_households, # Single adult, age 35 - "employment_income": employment_incomes, - "person_weight": [1.0] * n_households, - } - - # Create benunit data - one per household - benunit_data = { - "benunit_id": list(range(n_households)), - "benunit_weight": [1.0] * n_households, - } - - # Create household data - one per employment income level - household_data = { - "household_id": list(range(n_households)), - "household_weight": [1.0] * n_households, - "region": ["LONDON"] * n_households, # Required by policyengine-uk - "council_tax": [0.0] * n_households, # Simplified - no council tax - "rent": [0.0] * n_households, # Simplified - no rent - "tenure_type": ["RENT_PRIVATELY"] - * n_households, # Required for uprating - } - - # Create MicroDataFrames - person_df = MicroDataFrame( - pd.DataFrame(person_data), weights="person_weight" - ) - benunit_df = MicroDataFrame( - pd.DataFrame(benunit_data), weights="benunit_weight" - ) - household_df = MicroDataFrame( - pd.DataFrame(household_data), weights="household_weight" - ) - - # Create temporary file - tmpdir = tempfile.mkdtemp() - filepath = str(Path(tmpdir) / "employment_income_variation.h5") - - # Create dataset - dataset = PolicyEngineUKDataset( - name="Employment income variation", - description="Single adult household with varying employment income", - filepath=filepath, - year=year, - data=UKYearData( - person=person_df, - benunit=benunit_df, - household=household_df, - ), - ) - - return dataset - - -def run_simulation(dataset: PolicyEngineUKDataset) -> Simulation: - """Run a single simulation for all employment income variations.""" - simulation = Simulation( - dataset=dataset, - tax_benefit_model_version=uk_latest, - ) - simulation.run() - return simulation - - -def extract_results_by_employment_income( - simulation: Simulation, employment_incomes: list[float] -) -> dict: - """Extract HBAI household net income and components for each employment income level. - - Uses Aggregate with filters to extract specific households. 
- """ - hbai_net_income = [] - household_benefits = [] - household_tax = [] - employment_income_hh = [] - - for emp_income in employment_incomes: - # Get HBAI household net income - agg = Aggregate( - simulation=simulation, - variable="hbai_household_net_income", - aggregate_type=AggregateType.MEAN, - filter_variable="employment_income", - filter_variable_eq=emp_income, - entity="household", - ) - agg.run() - hbai_net_income.append(agg.result) - - # Get household benefits - agg = Aggregate( - simulation=simulation, - variable="household_benefits", - aggregate_type=AggregateType.MEAN, - filter_variable="employment_income", - filter_variable_eq=emp_income, - entity="household", - ) - agg.run() - household_benefits.append(agg.result) - - # Get household tax - agg = Aggregate( - simulation=simulation, - variable="household_tax", - aggregate_type=AggregateType.MEAN, - filter_variable="employment_income", - filter_variable_eq=emp_income, - entity="household", - ) - agg.run() - household_tax.append(agg.result) - - # Get employment income at household level - agg = Aggregate( - simulation=simulation, - variable="employment_income", - aggregate_type=AggregateType.MEAN, - filter_variable="employment_income", - filter_variable_eq=emp_income, - entity="household", - ) - agg.run() - employment_income_hh.append(agg.result) - - return { - "employment_income": employment_incomes, - "hbai_household_net_income": hbai_net_income, - "household_benefits": household_benefits, - "household_tax": household_tax, - "employment_income_hh": employment_income_hh, - } - - -def visualise_results(results: dict) -> None: - """Create a line chart showing HBAI household net income and components.""" - fig = go.Figure() - - # Main HBAI net income line - fig.add_trace( - go.Scatter( - x=results["employment_income"], - y=results["hbai_household_net_income"], - mode="lines+markers", - name="HBAI household net income", - line=dict(color="darkblue", width=3), - marker=dict(size=5), - ) - ) - - # Employment income (gross) - fig.add_trace( - go.Scatter( - x=results["employment_income"], - y=results["employment_income_hh"], - mode="lines", - name="Employment income (gross)", - line=dict(color="green", width=2, dash="dot"), - ) - ) - - # Household benefits - fig.add_trace( - go.Scatter( - x=results["employment_income"], - y=results["household_benefits"], - mode="lines", - name="Household benefits", - line=dict(color="orange", width=2), - ) - ) - - # Household tax (negative for visual clarity) - fig.add_trace( - go.Scatter( - x=results["employment_income"], - y=[-t for t in results["household_tax"]], - mode="lines", - name="Household tax (negative)", - line=dict(color="red", width=2), - ) - ) - - fig.update_layout( - title="HBAI household net income and components by employment income", - xaxis_title="Employment income (£)", - yaxis_title="Amount (£)", - height=600, - width=1000, - showlegend=True, - legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01), - ) - - fig.show() - - -def main(): - """Main execution function.""" - # Create employment income range from £0 to £100k - # Using smaller intervals at lower incomes where the relationship is more interesting - employment_incomes = ( - list(range(0, 20000, 1000)) # £0 to £20k in £1k steps - + list(range(20000, 50000, 2500)) # £20k to £50k in £2.5k steps - + list(range(50000, 100001, 5000)) # £50k to £100k in £5k steps - ) - - print( - f"Creating dataset with {len(employment_incomes)} employment income variations..." 
-    )
-    dataset = create_dataset_with_varied_employment_income(employment_incomes)
-
-    print("Running simulation (single run for all variations)...")
-    simulation = run_simulation(dataset)
-
-    print("Extracting results using aggregate filters...")
-    results = extract_results_by_employment_income(
-        simulation, employment_incomes
-    )
-
-    print("\nSample results:")
-    print(
-        f"Employment income £0: HBAI net income £{results['hbai_household_net_income'][0]:,.0f}"
-    )
-    print(
-        f"Employment income £25k: HBAI net income £{results['hbai_household_net_income'][employment_incomes.index(25000)]:,.0f}"
-    )
-    print(
-        f"Employment income £50k: HBAI net income £{results['hbai_household_net_income'][employment_incomes.index(50000)]:,.0f}"
-    )
-    print(
-        f"Employment income £100k: HBAI net income £{results['hbai_household_net_income'][-1]:,.0f}"
-    )
-
-    print("\nGenerating visualisation...")
-    visualise_results(results)
-
-
-if __name__ == "__main__":
-    main()
diff --git a/examples/employment_income_variation_uk.py b/examples/employment_income_variation_uk.py
new file mode 100644
index 00000000..7cd29a28
--- /dev/null
+++ b/examples/employment_income_variation_uk.py
@@ -0,0 +1,383 @@
+"""Example: Vary employment income and plot HBAI household net income.
+
+This script demonstrates:
+1. Creating a custom dataset with a single household template
+2. Varying employment income from £0 to £100k
+3. Running a single simulation for all variations
+4. Using Aggregate with filters to extract results by employment income
+5. Visualising the relationship between employment income and net income
+
+IMPORTANT NOTES FOR CUSTOM DATASETS:
+- Always set would_claim_* variables to True, otherwise benefits won't be claimed
+  even if the household is eligible (they default to random/False)
+- Always set disability variables explicitly (is_disabled_for_benefits, uc_limited_capability_for_WRA)
+  to prevent random UC spikes from the LCWRA element (£5,241/year extra if randomly assigned)
+- Must include join keys: person_benunit_id, person_household_id in person data
+- Required household fields: region, council_tax, rent, tenure_type
+- Person-level variables are mapped to household level using weights
+
+Run: python examples/employment_income_variation_uk.py
+"""
+
+import pandas as pd
+import tempfile
+from pathlib import Path
+import plotly.graph_objects as go
+from microdf import MicroDataFrame
+from policyengine.core import Simulation
+from policyengine.tax_benefit_models.uk import (
+    PolicyEngineUKDataset,
+    UKYearData,
+    uk_latest,
+)
+from policyengine.outputs.aggregate import Aggregate, AggregateType
+from policyengine.utils.plotting import format_fig, COLORS
+
+
+def create_dataset_with_varied_employment_income(
+    employment_incomes: list[float], year: int = 2026
+) -> PolicyEngineUKDataset:
+    """Create a dataset with one household template, varied by employment income.
+
+    Each household is a single adult with 2 children, paying median UK rent.
+    Employment income varies across households.
+ """ + n_households = len(employment_incomes) + n_people = n_households * 3 # 1 adult + 2 children per household + + # Create person data - one adult + 2 children per household + person_ids = [] + benunit_ids = [] + household_ids = [] + ages = [] + employment_income = [] + person_weights = [] + is_disabled = [] + limited_capability = [] + + person_id_counter = 0 + for hh_idx in range(n_households): + # Adult + person_ids.append(person_id_counter) + benunit_ids.append(hh_idx) + household_ids.append(hh_idx) + ages.append(35) + employment_income.append(employment_incomes[hh_idx]) + person_weights.append(1.0) + is_disabled.append(False) + limited_capability.append(False) + person_id_counter += 1 + + # Child 1 (age 8) + person_ids.append(person_id_counter) + benunit_ids.append(hh_idx) + household_ids.append(hh_idx) + ages.append(8) + employment_income.append(0.0) + person_weights.append(1.0) + is_disabled.append(False) + limited_capability.append(False) + person_id_counter += 1 + + # Child 2 (age 5) + person_ids.append(person_id_counter) + benunit_ids.append(hh_idx) + household_ids.append(hh_idx) + ages.append(5) + employment_income.append(0.0) + person_weights.append(1.0) + is_disabled.append(False) + limited_capability.append(False) + person_id_counter += 1 + + person_data = { + "person_id": person_ids, + "person_benunit_id": benunit_ids, + "person_household_id": household_ids, + "age": ages, + "employment_income": employment_income, + "person_weight": person_weights, + "is_disabled_for_benefits": is_disabled, + "uc_limited_capability_for_WRA": limited_capability, + } + + # Create benunit data - one per household + benunit_data = { + "benunit_id": list(range(n_households)), + "benunit_weight": [1.0] * n_households, + # Would claim variables - MUST set to True or benefits won't be claimed! 
+ "would_claim_uc": [True] * n_households, + "would_claim_WTC": [True] * n_households, + "would_claim_CTC": [True] * n_households, + "would_claim_IS": [True] * n_households, + "would_claim_pc": [True] * n_households, + "would_claim_child_benefit": [True] * n_households, + "would_claim_housing_benefit": [True] * n_households, + } + + # Create household data - one per employment income level + median_annual_rent = 850 * 12 # £850/month = £10,200/year (median UK rent) + household_data = { + "household_id": list(range(n_households)), + "household_weight": [1.0] * n_households, + "region": ["LONDON"] * n_households, # Required by policyengine-uk + "council_tax": [0.0] * n_households, # Simplified - no council tax + "rent": [median_annual_rent] * n_households, # Median UK rent + "tenure_type": ["RENT_PRIVATELY"] + * n_households, # Required for uprating + } + + # Create MicroDataFrames + person_df = MicroDataFrame( + pd.DataFrame(person_data), weights="person_weight" + ) + benunit_df = MicroDataFrame( + pd.DataFrame(benunit_data), weights="benunit_weight" + ) + household_df = MicroDataFrame( + pd.DataFrame(household_data), weights="household_weight" + ) + + # Create temporary file + tmpdir = tempfile.mkdtemp() + filepath = str(Path(tmpdir) / "employment_income_variation.h5") + + # Create dataset + dataset = PolicyEngineUKDataset( + name="Employment income variation", + description="Single adult household with varying employment income", + filepath=filepath, + year=year, + data=UKYearData( + person=person_df, + benunit=benunit_df, + household=household_df, + ), + ) + + return dataset + + +def run_simulation(dataset: PolicyEngineUKDataset) -> Simulation: + """Run a single simulation for all employment income variations.""" + # Specify additional variables to calculate beyond defaults + variables = { + "household": [ + # Default variables + "household_id", + "household_weight", + "household_net_income", + "hbai_household_net_income", + "household_benefits", + "household_tax", + ], + "person": [ + "person_id", + "benunit_id", + "household_id", + "person_weight", + "employment_income", + "age", + ], + "benunit": [ + "benunit_id", + "benunit_weight", + # Individual benefits (at benunit level) + "universal_credit", + "child_benefit", + "working_tax_credit", + "child_tax_credit", + "pension_credit", + "income_support", + ], + } + + simulation = Simulation( + dataset=dataset, + tax_benefit_model_version=uk_latest, + variables=variables, + ) + simulation.run() + return simulation + + +def extract_results_by_employment_income( + simulation: Simulation, employment_incomes: list[float] +) -> dict: + """Extract HBAI household net income and components for each employment income level. + + Uses Aggregate with filters to extract specific households. 
+ """ + hbai_net_income = [] + household_benefits = [] + household_tax = [] + employment_income_hh = [] + + # Individual benefits + universal_credit = [] + child_benefit = [] + working_tax_credit = [] + child_tax_credit = [] + pension_credit = [] + income_support = [] + + + for hh_idx, emp_income in enumerate(employment_incomes): + # Get HBAI household net income + agg = Aggregate( + simulation=simulation, + variable="hbai_household_net_income", + aggregate_type=AggregateType.MEAN, + filter_variable="household_id", + filter_variable_eq=hh_idx, + entity="household", + ) + agg.run() + hbai_net_income.append(agg.result) + + # Get household benefits + agg = Aggregate( + simulation=simulation, + variable="household_benefits", + aggregate_type=AggregateType.MEAN, + filter_variable="household_id", + filter_variable_eq=hh_idx, + entity="household", + ) + agg.run() + household_benefits.append(agg.result) + + # Get individual benefits (at benunit level, but we have 1:1 benunit:household mapping) + for benefit_name, benefit_list in [ + ("universal_credit", universal_credit), + ("child_benefit", child_benefit), + ("working_tax_credit", working_tax_credit), + ("child_tax_credit", child_tax_credit), + ("pension_credit", pension_credit), + ("income_support", income_support), + ]: + agg = Aggregate( + simulation=simulation, + variable=benefit_name, + aggregate_type=AggregateType.MEAN, + filter_variable="benunit_id", + filter_variable_eq=hh_idx, + entity="benunit", + ) + agg.run() + benefit_list.append(agg.result) + + + # Get household tax + agg = Aggregate( + simulation=simulation, + variable="household_tax", + aggregate_type=AggregateType.MEAN, + filter_variable="household_id", + filter_variable_eq=hh_idx, + entity="household", + ) + agg.run() + household_tax.append(agg.result) + + # Employment income at household level (just use the input value) + employment_income_hh.append(emp_income) + + return { + "employment_income": employment_incomes, + "hbai_household_net_income": hbai_net_income, + "household_benefits": household_benefits, + "household_tax": household_tax, + "employment_income_hh": employment_income_hh, + "universal_credit": universal_credit, + "child_benefit": child_benefit, + "working_tax_credit": working_tax_credit, + "child_tax_credit": child_tax_credit, + "pension_credit": pension_credit, + "income_support": income_support, + } + + +def visualise_results(results: dict) -> None: + """Create a stacked area chart showing income composition.""" + fig = go.Figure() + + # Calculate net employment income (employment income minus tax) + net_employment = [ + emp - tax + for emp, tax in zip(results["employment_income_hh"], results["household_tax"]) + ] + + # Stack benefits and income components using PolicyEngine colors + components = [ + ("Net employment income", net_employment, COLORS["primary"]), + ("Universal Credit", results["universal_credit"], COLORS["blue_secondary"]), + ("Working Tax Credit", results["working_tax_credit"], COLORS["info"]), + ("Child Tax Credit", results["child_tax_credit"], COLORS["success"]), + ("Child Benefit", results["child_benefit"], COLORS["warning"]), + ("Pension Credit", results["pension_credit"], COLORS["gray"]), + ("Income Support", results["income_support"], COLORS["gray_dark"]), + ] + + for name, values, color in components: + fig.add_trace( + go.Scatter( + x=results["employment_income"], + y=values, + name=name, + mode="lines", + line=dict(width=0.5, color=color), + stackgroup="one", + fillcolor=color, + ) + ) + + # Apply PolicyEngine styling + format_fig( + 
fig, + title="Household net income composition by employment income", + xaxis_title="Employment income (£)", + yaxis_title="Net income (£)", + show_legend=True, + height=700, + width=1200, + ) + + fig.show() + + +def main(): + """Main execution function.""" + # Create employment income range from £0 to £100k + # Using smaller intervals at lower incomes where the relationship is more interesting + employment_incomes = ( + list(range(0, 20000, 1000)) # £0 to £20k in £1k steps + + list(range(20000, 50000, 2500)) # £20k to £50k in £2.5k steps + + list(range(50000, 100001, 5000)) # £50k to £100k in £5k steps + ) + + print( + f"Creating dataset with {len(employment_incomes)} employment income variations..." + ) + dataset = create_dataset_with_varied_employment_income(employment_incomes) + + print("Running simulation (single run for all variations)...") + simulation = run_simulation(dataset) + + print("Extracting results using aggregate filters...") + results = extract_results_by_employment_income( + simulation, employment_incomes + ) + + print("\nSample results:") + for emp_inc in [0, 25000, 50000, 100000]: + idx = employment_incomes.index(emp_inc) if emp_inc in employment_incomes else -1 + if idx >= 0: + print(f" Employment income £{emp_inc:,}: HBAI net income £{results['hbai_household_net_income'][idx]:,.0f}") + + print("\nGenerating visualisation...") + visualise_results(results) + + +if __name__ == "__main__": + main() diff --git a/examples/employment_income_variation_us.py b/examples/employment_income_variation_us.py new file mode 100644 index 00000000..855c2e88 --- /dev/null +++ b/examples/employment_income_variation_us.py @@ -0,0 +1,356 @@ +"""Example: Vary employment income and plot household net income (US). + +This script demonstrates: +1. Creating a custom dataset with a single household template +2. Varying employment income from $0 to $200k +3. Running a single simulation for all variations +4. Using Aggregate with filters to extract results by employment income +5. Visualising the relationship between employment income and net income + +Run: python examples/employment_income_variation_us.py +""" + +import pandas as pd +import tempfile +from pathlib import Path +import plotly.graph_objects as go +from microdf import MicroDataFrame +from policyengine.core import Simulation +from policyengine.tax_benefit_models.us import ( + PolicyEngineUSDataset, + USYearData, + us_latest, +) +from policyengine.outputs.aggregate import Aggregate, AggregateType +from policyengine.utils.plotting import format_fig, COLORS + + +def create_dataset_with_varied_employment_income( + employment_incomes: list[float], year: int = 2024 +) -> PolicyEngineUSDataset: + """Create a dataset with one household template, varied by employment income. + + Each household is a single adult with 2 children. + Employment income varies across households. 
+ """ + n_households = len(employment_incomes) + n_people = n_households * 3 # 1 adult + 2 children per household + + # Create person data - one adult + 2 children per household + person_ids = [] + household_ids = [] + marital_unit_ids = [] + family_ids = [] + spm_unit_ids = [] + tax_unit_ids = [] + ages = [] + employment_income = [] + person_weights = [] + + person_id_counter = 0 + for hh_idx in range(n_households): + # Adult + person_ids.append(person_id_counter) + household_ids.append(hh_idx) + marital_unit_ids.append(hh_idx) + family_ids.append(hh_idx) + spm_unit_ids.append(hh_idx) + tax_unit_ids.append(hh_idx) + ages.append(35) + employment_income.append(employment_incomes[hh_idx]) + person_weights.append(1000.0) + person_id_counter += 1 + + # Child 1 (age 8) + person_ids.append(person_id_counter) + household_ids.append(hh_idx) + marital_unit_ids.append(hh_idx) + family_ids.append(hh_idx) + spm_unit_ids.append(hh_idx) + tax_unit_ids.append(hh_idx) + ages.append(8) + employment_income.append(0.0) + person_weights.append(1000.0) + person_id_counter += 1 + + # Child 2 (age 5) + person_ids.append(person_id_counter) + household_ids.append(hh_idx) + marital_unit_ids.append(hh_idx) + family_ids.append(hh_idx) + spm_unit_ids.append(hh_idx) + tax_unit_ids.append(hh_idx) + ages.append(5) + employment_income.append(0.0) + person_weights.append(1000.0) + person_id_counter += 1 + + person_data = { + "person_id": person_ids, + "household_id": household_ids, + "marital_unit_id": marital_unit_ids, + "family_id": family_ids, + "spm_unit_id": spm_unit_ids, + "tax_unit_id": tax_unit_ids, + "age": ages, + "employment_income": employment_income, + "person_weight": person_weights, + } + + # Create household data + household_data = { + "household_id": list(range(n_households)), + "state_name": ["CA"] * n_households, # California + "household_weight": [1000.0] * n_households, + } + + # Create group entity data + marital_unit_data = { + "marital_unit_id": list(range(n_households)), + "marital_unit_weight": [1000.0] * n_households, + } + + family_data = { + "family_id": list(range(n_households)), + "family_weight": [1000.0] * n_households, + } + + spm_unit_data = { + "spm_unit_id": list(range(n_households)), + "spm_unit_weight": [1000.0] * n_households, + } + + tax_unit_data = { + "tax_unit_id": list(range(n_households)), + "tax_unit_weight": [1000.0] * n_households, + } + + # Create MicroDataFrames + person_df = MicroDataFrame( + pd.DataFrame(person_data), weights="person_weight" + ) + household_df = MicroDataFrame( + pd.DataFrame(household_data), weights="household_weight" + ) + marital_unit_df = MicroDataFrame( + pd.DataFrame(marital_unit_data), weights="marital_unit_weight" + ) + family_df = MicroDataFrame( + pd.DataFrame(family_data), weights="family_weight" + ) + spm_unit_df = MicroDataFrame( + pd.DataFrame(spm_unit_data), weights="spm_unit_weight" + ) + tax_unit_df = MicroDataFrame( + pd.DataFrame(tax_unit_data), weights="tax_unit_weight" + ) + + # Create temporary file + tmpdir = tempfile.mkdtemp() + filepath = str(Path(tmpdir) / "employment_income_variation_us.h5") + + # Create dataset + dataset = PolicyEngineUSDataset( + name="Employment income variation (US)", + description="Single adult household with 2 children, varying employment income", + filepath=filepath, + year=year, + data=USYearData( + person=person_df, + household=household_df, + marital_unit=marital_unit_df, + family=family_df, + spm_unit=spm_unit_df, + tax_unit=tax_unit_df, + ), + ) + + return dataset + + +def run_simulation(dataset: 
PolicyEngineUSDataset) -> Simulation: + """Run a single simulation for all employment income variations.""" + # Specify variables to calculate + variables = { + "household": [ + "household_id", + "household_weight", + "household_net_income", + "household_benefits", + "household_tax", + "household_market_income", + ], + "person": [ + "person_id", + "household_id", + "marital_unit_id", + "family_id", + "spm_unit_id", + "tax_unit_id", + "person_weight", + "employment_income", + "age", + ], + "spm_unit": [ + "spm_unit_id", + "spm_unit_weight", + "snap", + "tanf", + "spm_unit_net_income", + ], + "tax_unit": [ + "tax_unit_id", + "tax_unit_weight", + "income_tax", + "employee_payroll_tax", + "eitc", + "ctc", + ], + "marital_unit": [ + "marital_unit_id", + "marital_unit_weight", + ], + "family": [ + "family_id", + "family_weight", + ], + } + + simulation = Simulation( + dataset=dataset, + tax_benefit_model_version=us_latest, + variables=variables, + ) + simulation.run() + return simulation + + +def extract_results_by_employment_income( + simulation: Simulation, employment_incomes: list[float] +) -> dict: + """Extract household net income and components for each employment income level. + + Directly accesses output data by row index since we have one household per income level. + """ + import pandas as pd + + # Get output data + household_df = pd.DataFrame(simulation.output_dataset.data.household) + spm_unit_df = pd.DataFrame(simulation.output_dataset.data.spm_unit) + tax_unit_df = pd.DataFrame(simulation.output_dataset.data.tax_unit) + + # Extract results (one row per household/spm_unit/tax_unit) + household_net_income = household_df["household_net_income"].tolist() + household_benefits = household_df["household_benefits"].tolist() + household_tax = household_df["household_tax"].tolist() + + snap = spm_unit_df["snap"].tolist() + tanf = spm_unit_df["tanf"].tolist() + + eitc = tax_unit_df["eitc"].tolist() + ctc = tax_unit_df["ctc"].tolist() + + employment_income_hh = employment_incomes + + return { + "employment_income": employment_incomes, + "household_net_income": household_net_income, + "household_benefits": household_benefits, + "household_tax": household_tax, + "employment_income_hh": employment_income_hh, + "snap": snap, + "tanf": tanf, + "eitc": eitc, + "ctc": ctc, + } + + +def visualise_results(results: dict) -> None: + """Create a stacked area chart showing income composition.""" + fig = go.Figure() + + # Calculate net employment income (employment income minus tax) + net_employment = [ + emp - tax + for emp, tax in zip( + results["employment_income_hh"], results["household_tax"] + ) + ] + + # Stack benefits and income components using PolicyEngine colors + components = [ + ("Net employment income", net_employment, COLORS["primary"]), + ("SNAP", results["snap"], COLORS["blue_secondary"]), + ("TANF", results["tanf"], COLORS["info"]), + ("EITC", results["eitc"], COLORS["success"]), + ("CTC", results["ctc"], COLORS["warning"]), + ] + + for name, values, color in components: + fig.add_trace( + go.Scatter( + x=results["employment_income"], + y=values, + name=name, + mode="lines", + line=dict(width=0.5, color=color), + stackgroup="one", + fillcolor=color, + ) + ) + + # Apply PolicyEngine styling + format_fig( + fig, + title="Household net income composition by employment income", + xaxis_title="Employment income ($)", + yaxis_title="Net income ($)", + show_legend=True, + height=700, + width=1200, + ) + + fig.show() + + +def main(): + """Main execution function.""" + # Create employment income 
range from $0 to $200k + # Using smaller intervals at lower incomes where the relationship is more interesting + employment_incomes = ( + list(range(0, 40000, 2000)) # $0 to $40k in $2k steps + + list(range(40000, 100000, 5000)) # $40k to $100k in $5k steps + + list(range(100000, 200001, 10000)) # $100k to $200k in $10k steps + ) + + print( + f"Creating dataset with {len(employment_incomes)} employment income variations..." + ) + dataset = create_dataset_with_varied_employment_income(employment_incomes) + + print("Running simulation (single run for all variations)...") + simulation = run_simulation(dataset) + + print("Extracting results using aggregate filters...") + results = extract_results_by_employment_income( + simulation, employment_incomes + ) + + print("\nSample results:") + for emp_inc in [0, 50000, 100000, 200000]: + idx = ( + employment_incomes.index(emp_inc) + if emp_inc in employment_incomes + else -1 + ) + if idx >= 0: + print( + f" Employment income ${emp_inc:,}: household net income ${results['household_net_income'][idx]:,.0f}" + ) + + print("\nGenerating visualisation...") + visualise_results(results) + + +if __name__ == "__main__": + main() diff --git a/src/policyengine/tax_benefit_models/us/__init__.py b/src/policyengine/tax_benefit_models/us/__init__.py index cbb58218..b4c0805a 100644 --- a/src/policyengine/tax_benefit_models/us/__init__.py +++ b/src/policyengine/tax_benefit_models/us/__init__.py @@ -3,6 +3,7 @@ from importlib.util import find_spec if find_spec("policyengine_us") is not None: + from policyengine.core import Dataset from .datasets import USYearData, PolicyEngineUSDataset from .model import ( PolicyEngineUS, @@ -13,6 +14,12 @@ from .analysis import general_policy_reform_analysis from .outputs import ProgramStatistics + # Rebuild Pydantic models to resolve forward references + Dataset.model_rebuild() + USYearData.model_rebuild() + PolicyEngineUSDataset.model_rebuild() + PolicyEngineUSLatest.model_rebuild() + __all__ = [ "USYearData", "PolicyEngineUSDataset", diff --git a/src/policyengine/tax_benefit_models/us/datasets.py b/src/policyengine/tax_benefit_models/us/datasets.py index 8545cfc7..4027feeb 100644 --- a/src/policyengine/tax_benefit_models/us/datasets.py +++ b/src/policyengine/tax_benefit_models/us/datasets.py @@ -86,16 +86,27 @@ def map_to_entity( # Map to different entity target_df = getattr(self, target_entity) - - # Direct mapping if join key exists in source target_key = f"{target_entity}_id" - if target_key in pd.DataFrame(source_df).columns: - result = pd.DataFrame(target_df).merge( - pd.DataFrame(source_df), on=target_key, how="left" - ) - return MicroDataFrame(result, weights=target_weight) - # For more complex mappings, go through person table + # Person to group entity: aggregate person-level data to group level + if source_entity == "person" and target_entity != "person": + if target_key in pd.DataFrame(source_df).columns: + # Merge source (person) with target (group) on target_key + result = pd.DataFrame(target_df).merge( + pd.DataFrame(source_df), on=target_key, how="left" + ) + return MicroDataFrame(result, weights=target_weight) + + # Group entity to person: expand group-level data to person level + if source_entity != "person" and target_entity == "person": + source_key = f"{source_entity}_id" + if source_key in pd.DataFrame(target_df).columns: + result = pd.DataFrame(target_df).merge( + pd.DataFrame(source_df), on=source_key, how="left" + ) + return MicroDataFrame(result, weights=target_weight) + + # Group to group: go through 
person table if source_entity != "person" and target_entity != "person": # Get person link table with both entity IDs person_df = pd.DataFrame(self.person) @@ -127,14 +138,17 @@ class PolicyEngineUSDataset(Dataset): data: USYearData | None = None - def __init__(self, **kwargs: dict): - super().__init__(**kwargs) - + def model_post_init(self, __context): + """Called after Pydantic initialization.""" # Make sure we are synchronised between in-memory and storage, at least on initialisation - if "data" in kwargs: + if self.data is not None: self.save() - elif "filepath" in kwargs: - self.load() + elif self.filepath and not self.data: + try: + self.load() + except FileNotFoundError: + # File doesn't exist yet, that's OK + pass def save(self) -> None: """Save dataset to HDF5 file.""" diff --git a/src/policyengine/tax_benefit_models/us/model.py b/src/policyengine/tax_benefit_models/us/model.py index 9c4196f7..34a3528d 100644 --- a/src/policyengine/tax_benefit_models/us/model.py +++ b/src/policyengine/tax_benefit_models/us/model.py @@ -26,22 +26,29 @@ class PolicyEngineUS(TaxBenefitModel): us_model = PolicyEngineUS() -pkg_version = version("policyengine-us") -# Get published time from PyPI -response = requests.get("https://pypi.org/pypi/policyengine-us/json") -data = response.json() -upload_time = data["releases"][pkg_version][0]["upload_time_iso_8601"] +def _get_us_package_metadata(): + """Get PolicyEngine US package version and upload time (lazy-loaded).""" + pkg_version = version("policyengine-us") + # Get published time from PyPI + response = requests.get("https://pypi.org/pypi/policyengine-us/json") + data = response.json() + upload_time = data["releases"][pkg_version][0]["upload_time_iso_8601"] + return pkg_version, upload_time class PolicyEngineUSLatest(TaxBenefitModelVersion): model: TaxBenefitModel = us_model - version: str = pkg_version - created_at: datetime.datetime = datetime.datetime.fromisoformat( - upload_time - ) + version: str = None + created_at: datetime.datetime = None def __init__(self, **kwargs: dict): + # Lazy-load package metadata if not provided + if "version" not in kwargs or kwargs.get("version") is None: + pkg_version, upload_time = _get_us_package_metadata() + kwargs["version"] = pkg_version + kwargs["created_at"] = datetime.datetime.fromisoformat(upload_time) + super().__init__(**kwargs) from policyengine_us.system import system from policyengine_core.enums import Enum @@ -107,16 +114,25 @@ def __init__(self, **kwargs: dict): def run(self, simulation: "Simulation") -> "Simulation": from policyengine_us import Microsimulation + from policyengine_us.system import system + from policyengine_core.simulations.simulation_builder import ( + SimulationBuilder, + ) from policyengine.utils.parametric_reforms import ( simulation_modifier_from_parameter_values, ) + import numpy as np assert isinstance(simulation.dataset, PolicyEngineUSDataset) dataset = simulation.dataset dataset.load() - microsim = Microsimulation(dataset=None) + # Build simulation from entity IDs using PolicyEngine Core pattern + microsim = Microsimulation() + self._build_simulation_from_dataset(microsim, dataset, system) + + # Apply policy reforms if ( simulation.policy and simulation.policy.simulation_modifier is not None @@ -128,6 +144,7 @@ def run(self, simulation: "Simulation") -> "Simulation": ) modifier(microsim) + # Apply dynamic reforms if ( simulation.dynamic and simulation.dynamic.simulation_modifier is not None @@ -192,8 +209,7 @@ def run(self, simulation: "Simulation") -> "Simulation": 
"tax_unit_id", "tax_unit_weight", "income_tax", - "payroll_tax", - "state_income_tax", + "employee_payroll_tax", "eitc", "ctc", "adjusted_gross_income", @@ -263,5 +279,103 @@ def run(self, simulation: "Simulation") -> "Simulation": simulation.output_dataset.save() + def _build_simulation_from_dataset(self, microsim, dataset, system): + """Build a PolicyEngine Core simulation from dataset entity IDs. + + This follows the same pattern as policyengine-uk, initializing + entities from IDs first, then using set_input() for variables. + + Args: + microsim: The Microsimulation object to populate + dataset: The dataset containing entity data + system: The tax-benefit system + """ + from policyengine_core.simulations.simulation_builder import ( + SimulationBuilder, + ) + import numpy as np + + # Create builder and instantiate entities + builder = SimulationBuilder() + builder.populations = system.instantiate_entities() + + # Extract entity IDs from dataset + person_data = pd.DataFrame(dataset.data.person) + + # Declare entities + builder.declare_person_entity( + "person", person_data["person_id"].values + ) + builder.declare_entity( + "household", np.unique(person_data["household_id"].values) + ) + builder.declare_entity( + "spm_unit", np.unique(person_data["spm_unit_id"].values) + ) + builder.declare_entity( + "family", np.unique(person_data["family_id"].values) + ) + builder.declare_entity( + "tax_unit", np.unique(person_data["tax_unit_id"].values) + ) + builder.declare_entity( + "marital_unit", np.unique(person_data["marital_unit_id"].values) + ) + + # Join persons to group entities + builder.join_with_persons( + builder.populations["household"], + person_data["household_id"].values, + np.array(["member"] * len(person_data)), + ) + builder.join_with_persons( + builder.populations["spm_unit"], + person_data["spm_unit_id"].values, + np.array(["member"] * len(person_data)), + ) + builder.join_with_persons( + builder.populations["family"], + person_data["family_id"].values, + np.array(["member"] * len(person_data)), + ) + builder.join_with_persons( + builder.populations["tax_unit"], + person_data["tax_unit_id"].values, + np.array(["member"] * len(person_data)), + ) + builder.join_with_persons( + builder.populations["marital_unit"], + person_data["marital_unit_id"].values, + np.array(["member"] * len(person_data)), + ) + + # Build simulation from populations + microsim.build_from_populations(builder.populations) + + # Set input variables for each entity + # Skip ID columns as they're structural and already used in entity building + id_columns = { + "person_id", + "household_id", + "spm_unit_id", + "family_id", + "tax_unit_id", + "marital_unit_id", + } + + for entity_name, entity_df in [ + ("person", dataset.data.person), + ("household", dataset.data.household), + ("spm_unit", dataset.data.spm_unit), + ("family", dataset.data.family), + ("tax_unit", dataset.data.tax_unit), + ("marital_unit", dataset.data.marital_unit), + ]: + df = pd.DataFrame(entity_df) + for column in df.columns: + # Skip ID columns and check if variable exists in system + if column not in id_columns and column in system.variables: + microsim.set_input(column, dataset.year, df[column].values) + us_latest = PolicyEngineUSLatest() diff --git a/src/policyengine/utils/__init__.py b/src/policyengine/utils/__init__.py index 6761220d..ac764329 100644 --- a/src/policyengine/utils/__init__.py +++ b/src/policyengine/utils/__init__.py @@ -1 +1,2 @@ from .dates import parse_safe_date +from .plotting import format_fig, COLORS diff --git 
a/src/policyengine/utils/dates.py b/src/policyengine/utils/dates.py index 6bcacab1..d2439456 100644 --- a/src/policyengine/utils/dates.py +++ b/src/policyengine/utils/dates.py @@ -1,9 +1,11 @@ from datetime import datetime +import calendar def parse_safe_date(date_string: str) -> datetime: """ Parse a YYYY-MM-DD date string and ensure the year is at least 1. + Handles invalid day values by capping to the last valid day of the month. Args: date_string: Date string in YYYY-MM-DD format @@ -18,7 +20,21 @@ def parse_safe_date(date_string: str) -> datetime: # Replace year 0 or negative years with year 1 return date_obj.replace(year=1) return date_obj - except ValueError: + except ValueError as e: + # Try to handle invalid day values (e.g., 2021-06-31) + if "day is out of range for month" in str(e): + parts = date_string.split("-") + if len(parts) == 3: + year = int(parts[0]) + month = int(parts[1]) + # Get the last valid day of the month + last_day = calendar.monthrange(year, month)[1] + # Use the last valid day instead + corrected_date = f"{year:04d}-{month:02d}-{last_day:02d}" + date_obj = datetime.strptime(corrected_date, "%Y-%m-%d") + if date_obj.year < 1: + return date_obj.replace(year=1) + return date_obj raise ValueError( f"Invalid date format: {date_string}. Expected YYYY-MM-DD" ) diff --git a/src/policyengine/utils/plotting.py b/src/policyengine/utils/plotting.py new file mode 100644 index 00000000..661ab19e --- /dev/null +++ b/src/policyengine/utils/plotting.py @@ -0,0 +1,178 @@ +"""Plotting utilities for PolicyEngine visualisations.""" + +from typing import Optional +import plotly.graph_objects as go + + +# PolicyEngine brand colours +COLORS = { + "primary": "#319795", # Teal + "primary_light": "#E6FFFA", + "primary_dark": "#1D4044", + "success": "#22C55E", # Green (positive changes) + "warning": "#FEC601", # Yellow (cautions) + "error": "#EF4444", # Red (negative changes) + "info": "#1890FF", # Blue (neutral info) + "gray_light": "#F2F4F7", + "gray": "#667085", + "gray_dark": "#101828", + "blue_secondary": "#026AA2", +} + +# Typography +FONT_FAMILY = "Inter, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif" +FONT_SIZE_LABEL = 12 +FONT_SIZE_DEFAULT = 14 +FONT_SIZE_TITLE = 16 + + +def format_fig( + fig: go.Figure, + title: Optional[str] = None, + xaxis_title: Optional[str] = None, + yaxis_title: Optional[str] = None, + show_legend: bool = True, + height: Optional[int] = None, + width: Optional[int] = None, +) -> go.Figure: + """Apply PolicyEngine visual style to a plotly figure. 
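+
+    Colours come from the module-level COLORS mapping defined above (for
+    example COLORS["primary"] for teal or COLORS["error"] for red); pass
+    these to trace constructors so charts stay on brand.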
+ + Applies professional, clean styling following PolicyEngine design principles: + - Data-driven clarity prioritising immediate understanding + - Professional brand colours (teal primary, semantic colours) + - Clean typography with Inter font family + - Minimal visual clutter + - Appropriate spacing and margins + + Args: + fig: Plotly figure to format + title: Optional title to set/override + xaxis_title: Optional x-axis title to set/override + yaxis_title: Optional y-axis title to set/override + show_legend: Whether to show the legend (default: True) + height: Optional height in pixels + width: Optional width in pixels + + Returns: + Formatted plotly figure (same object, modified in place) + + Example: + >>> import plotly.graph_objects as go + >>> from policyengine.utils import format_fig + >>> fig = go.Figure(data=go.Scatter(x=[1, 2, 3], y=[4, 5, 6])) + >>> format_fig(fig, title="Example chart", xaxis_title="X", yaxis_title="Y") + """ + # Build layout updates + layout_updates = { + "font": { + "family": FONT_FAMILY, + "size": FONT_SIZE_DEFAULT, + "color": COLORS["gray_dark"], + }, + "plot_bgcolor": "#FAFAFA", + "paper_bgcolor": "white", + "margin": {"l": 100, "r": 60, "t": 100, "b": 80}, + "showlegend": show_legend, + "xaxis": { + "title": { + "font": { + "size": FONT_SIZE_DEFAULT, + "family": FONT_FAMILY, + "color": COLORS["gray_dark"], + }, + "standoff": 20, + }, + "tickfont": { + "size": FONT_SIZE_LABEL, + "family": FONT_FAMILY, + "color": COLORS["gray"], + }, + "showgrid": False, + "showline": True, + "linewidth": 2, + "linecolor": COLORS["gray_light"], + "zeroline": False, + "ticks": "outside", + "tickwidth": 1, + "tickcolor": COLORS["gray_light"], + }, + "yaxis": { + "title": { + "font": { + "size": FONT_SIZE_DEFAULT, + "family": FONT_FAMILY, + "color": COLORS["gray_dark"], + }, + "standoff": 20, + }, + "tickfont": { + "size": FONT_SIZE_LABEL, + "family": FONT_FAMILY, + "color": COLORS["gray"], + }, + "showgrid": True, + "gridwidth": 1, + "gridcolor": "#E5E7EB", + "showline": False, + "zeroline": False, + }, + "legend": { + "bgcolor": "white", + "bordercolor": COLORS["gray_light"], + "borderwidth": 1, + "font": {"size": FONT_SIZE_LABEL, "family": FONT_FAMILY}, + "orientation": "v", + "yanchor": "top", + "y": 0.99, + "xanchor": "right", + "x": 0.99, + }, + } + + # Add optional parameters + if title is not None: + layout_updates["title"] = { + "text": title, + "font": { + "size": 18, + "family": FONT_FAMILY, + "color": COLORS["gray_dark"], + "weight": 600, + }, + "x": 0, + "xanchor": "left", + "y": 0.98, + "yanchor": "top", + } + + if xaxis_title is not None: + layout_updates["xaxis"]["title"]["text"] = xaxis_title + + if yaxis_title is not None: + layout_updates["yaxis"]["title"]["text"] = yaxis_title + + if height is not None: + layout_updates["height"] = height + + if width is not None: + layout_updates["width"] = width + + # Apply layout + fig.update_layout(**layout_updates) + + # Update all traces to have cleaner styling + fig.update_traces( + marker=dict(size=8, line=dict(width=0)), + line=dict(width=3), + selector=dict(mode="markers+lines"), + ) + fig.update_traces( + marker=dict(size=8, line=dict(width=0)), + selector=dict(mode="markers"), + ) + fig.update_traces( + line=dict(width=3), + selector=dict(mode="lines"), + ) + + return fig diff --git a/tests/test_us_simulation.py b/tests/test_us_simulation.py new file mode 100644 index 00000000..b3df9a67 --- /dev/null +++ b/tests/test_us_simulation.py @@ -0,0 +1,260 @@ +import pandas as pd +import tempfile +import os +from 
microdf import MicroDataFrame +from policyengine.core import Simulation +from policyengine.tax_benefit_models.us import ( + PolicyEngineUSDataset, + USYearData, + us_latest, +) + + +def test_us_latest_instantiation(): + """Test that us_latest can be instantiated without errors.""" + assert us_latest is not None + assert us_latest.version is not None + assert us_latest.model is not None + assert us_latest.created_at is not None + assert ( + len(us_latest.variables) > 0 + ) # Should have variables from policyengine-us + + +def test_save_and_load_us_dataset(): + """Test saving and loading a US dataset.""" + # Create sample data with minimal required columns + person_df = MicroDataFrame( + pd.DataFrame( + { + "person_id": [0, 1, 2], + "household_id": [0, 0, 1], + "marital_unit_id": [0, 0, 1], + "family_id": [0, 0, 1], + "spm_unit_id": [0, 0, 1], + "tax_unit_id": [0, 0, 1], + "age": [30, 35, 25], + "employment_income": [50000, 60000, 40000], + "person_weight": [1000.0, 1000.0, 1000.0], + } + ), + weights="person_weight", + ) + + household_df = MicroDataFrame( + pd.DataFrame( + { + "household_id": [0, 1], + "household_weight": [1000.0, 1000.0], + } + ), + weights="household_weight", + ) + + marital_unit_df = MicroDataFrame( + pd.DataFrame( + { + "marital_unit_id": [0, 1], + "marital_unit_weight": [1000.0, 1000.0], + } + ), + weights="marital_unit_weight", + ) + + family_df = MicroDataFrame( + pd.DataFrame( + { + "family_id": [0, 1], + "family_weight": [1000.0, 1000.0], + } + ), + weights="family_weight", + ) + + spm_unit_df = MicroDataFrame( + pd.DataFrame( + { + "spm_unit_id": [0, 1], + "spm_unit_weight": [1000.0, 1000.0], + } + ), + weights="spm_unit_weight", + ) + + tax_unit_df = MicroDataFrame( + pd.DataFrame( + { + "tax_unit_id": [0, 1], + "tax_unit_weight": [1000.0, 1000.0], + } + ), + weights="tax_unit_weight", + ) + + # Create dataset + with tempfile.TemporaryDirectory() as tmpdir: + filepath = os.path.join(tmpdir, "test_us_dataset.h5") + + dataset = PolicyEngineUSDataset( + name="Test US Dataset", + description="A test US dataset", + filepath=filepath, + year=2024, + data=USYearData( + person=person_df, + household=household_df, + marital_unit=marital_unit_df, + family=family_df, + spm_unit=spm_unit_df, + tax_unit=tax_unit_df, + ), + ) + + # Save to file + dataset.save() + + # Load it back + loaded = PolicyEngineUSDataset( + name="Loaded US Dataset", + description="Loaded from file", + filepath=filepath, + year=2024, + ) + loaded.load() + + # Verify data + assert loaded.year == 2024 + pd.testing.assert_frame_equal( + pd.DataFrame(loaded.data.person), pd.DataFrame(person_df) + ) + pd.testing.assert_frame_equal( + pd.DataFrame(loaded.data.household), pd.DataFrame(household_df) + ) + + +def test_us_simulation_from_dataset(): + """Test running a US simulation from a dataset using PolicyEngine Core pattern.""" + # Create a small test dataset + person_df = MicroDataFrame( + pd.DataFrame( + { + "person_id": [0, 1], + "household_id": [0, 0], + "marital_unit_id": [0, 0], + "family_id": [0, 0], + "spm_unit_id": [0, 0], + "tax_unit_id": [0, 0], + "age": [30, 35], + "employment_income": [50000, 60000], + "person_weight": [1000.0, 1000.0], + } + ), + weights="person_weight", + ) + + household_df = MicroDataFrame( + pd.DataFrame( + { + "household_id": [0], + "state_name": ["CA"], + "household_weight": [1000.0], + } + ), + weights="household_weight", + ) + + marital_unit_df = MicroDataFrame( + pd.DataFrame( + { + "marital_unit_id": [0], + "marital_unit_weight": [1000.0], + } + ), + 
weights="marital_unit_weight", + ) + + family_df = MicroDataFrame( + pd.DataFrame( + { + "family_id": [0], + "family_weight": [1000.0], + } + ), + weights="family_weight", + ) + + spm_unit_df = MicroDataFrame( + pd.DataFrame( + { + "spm_unit_id": [0], + "spm_unit_weight": [1000.0], + } + ), + weights="spm_unit_weight", + ) + + tax_unit_df = MicroDataFrame( + pd.DataFrame( + { + "tax_unit_id": [0], + "tax_unit_weight": [1000.0], + } + ), + weights="tax_unit_weight", + ) + + with tempfile.TemporaryDirectory() as tmpdir: + filepath = os.path.join(tmpdir, "test_simulation.h5") + + dataset = PolicyEngineUSDataset( + name="Test Simulation Dataset", + description="Dataset for testing simulation", + filepath=filepath, + year=2024, + data=USYearData( + person=person_df, + household=household_df, + marital_unit=marital_unit_df, + family=family_df, + spm_unit=spm_unit_df, + tax_unit=tax_unit_df, + ), + ) + + # Create and run simulation + simulation = Simulation( + dataset=dataset, + tax_benefit_model_version=us_latest, + variables={ + "person": [ + "person_id", + "person_weight", + "age", + "employment_income", + ], + "household": ["household_id", "household_weight"], + "marital_unit": ["marital_unit_id", "marital_unit_weight"], + "family": ["family_id", "family_weight"], + "spm_unit": ["spm_unit_id", "spm_unit_weight"], + "tax_unit": ["tax_unit_id", "tax_unit_weight"], + }, + ) + + simulation.run() + + # Verify output dataset was created + assert simulation.output_dataset is not None + assert simulation.output_dataset.data is not None + + # Verify person data contains the expected variables + person_output = pd.DataFrame(simulation.output_dataset.data.person) + assert "person_id" in person_output.columns + assert "age" in person_output.columns + assert "employment_income" in person_output.columns + assert len(person_output) == 2 # Should have 2 people + + # Verify employment income values match input + assert person_output["employment_income"].tolist() == [ + 50000, + 60000, + ] From 6d11846c8bb9e798343e3216833077ddfe84485b Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Sun, 16 Nov 2025 21:51:36 +0000 Subject: [PATCH 26/35] Add macro US output --- .env.example | 15 - .../{income_bands.py => income_bands_uk.py} | 0 examples/income_distribution_us.py | 343 ++++++++++++++++++ .../{policy_change.py => policy_change_uk.py} | 0 .../tax_benefit_models/us/__init__.py | 3 +- .../tax_benefit_models/us/datasets.py | 180 +++++++++ .../tax_benefit_models/us/model.py | 70 ++-- tests/test_us_datasets.py | 108 ++++++ 8 files changed, 680 insertions(+), 39 deletions(-) delete mode 100644 .env.example rename examples/{income_bands.py => income_bands_uk.py} (100%) create mode 100644 examples/income_distribution_us.py rename examples/{policy_change.py => policy_change_uk.py} (100%) create mode 100644 tests/test_us_datasets.py diff --git a/.env.example b/.env.example deleted file mode 100644 index 413cd7c1..00000000 --- a/.env.example +++ /dev/null @@ -1,15 +0,0 @@ -## Copy this file to `.env` and fill in the values as needed. - -# Local development database (default if not set) -DATABASE_URL=sqlite:///policyengine.db - -# PolicyEngine live database connection pieces (used when --db-location policyengine) -# The CLI composes the URL as postgresql+psycopg2://... with sslmode=require by default. 
-POLICYENGINE_DB_PASSWORD= -POLICYENGINE_DB_USER=postgres -POLICYENGINE_DB_HOST=db.usugnrssspkdutcjeevk.supabase.co -POLICYENGINE_DB_PORT=5432 -POLICYENGINE_DB_NAME=postgres - -# Optional: Hugging Face token for private repos when seeding datasets from HF -HUGGING_FACE_TOKEN= diff --git a/examples/income_bands.py b/examples/income_bands_uk.py similarity index 100% rename from examples/income_bands.py rename to examples/income_bands_uk.py diff --git a/examples/income_distribution_us.py b/examples/income_distribution_us.py new file mode 100644 index 00000000..02173845 --- /dev/null +++ b/examples/income_distribution_us.py @@ -0,0 +1,343 @@ +"""Example: Plot US household income distribution using enhanced CPS microdata. + +This script demonstrates: +1. Loading enhanced CPS representative household microdata +2. Running a full microsimulation to calculate household income and tax +3. Using Aggregate to calculate statistics within income deciles +4. Visualising the income distribution across the United States + +Run: python examples/income_distribution_us.py +""" + +from pathlib import Path +import plotly.graph_objects as go +from plotly.subplots import make_subplots +from policyengine.core import Simulation +from policyengine.tax_benefit_models.us import ( + PolicyEngineUSDataset, + us_latest, +) +from policyengine.outputs.aggregate import Aggregate, AggregateType +from policyengine.utils.plotting import format_fig, COLORS + + +def load_representative_data(year: int = 2024) -> PolicyEngineUSDataset: + """Load representative household microdata for a given year.""" + dataset_path = ( + Path(__file__).parent / "data" / f"enhanced_cps_2024_year_{year}.h5" + ) + + if not dataset_path.exists(): + raise FileNotFoundError( + f"Dataset not found at {dataset_path}. " + "Run create_datasets() from policyengine.tax_benefit_models.us first." 
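+            # Note: create_datasets() saves under ./data/ in the current
+            # working directory, while this loader reads examples/data/;
+            # copy or symlink the generated .h5 file if those paths differ.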
+ ) + + dataset = PolicyEngineUSDataset( + name=f"Enhanced CPS {year}", + description=f"Representative household microdata for {year}", + filepath=str(dataset_path), + year=year, + ) + dataset.load() + return dataset + + +def run_simulation(dataset: PolicyEngineUSDataset) -> Simulation: + """Run a microsimulation on the dataset.""" + simulation = Simulation( + dataset=dataset, + tax_benefit_model_version=us_latest, + ) + simulation.run() + return simulation + + +def calculate_income_decile_statistics(simulation: Simulation) -> dict: + """Calculate total income, tax, and benefits by income deciles.""" + deciles = [f"D{i}" for i in range(1, 11)] + market_incomes = [] + taxes = [] + benefits = [] + net_incomes = [] + counts = [] + + # Calculate household-level aggregates by decile + for decile_num in range(1, 11): + agg = Aggregate( + simulation=simulation, + variable="household_market_income", + aggregate_type=AggregateType.SUM, + filter_variable="household_net_income", + quantile=10, + quantile_eq=decile_num, + ) + agg.run() + market_incomes.append(agg.result / 1e9) + + agg = Aggregate( + simulation=simulation, + variable="household_tax", + aggregate_type=AggregateType.SUM, + filter_variable="household_net_income", + quantile=10, + quantile_eq=decile_num, + ) + agg.run() + taxes.append(agg.result / 1e9) + + agg = Aggregate( + simulation=simulation, + variable="household_benefits", + aggregate_type=AggregateType.SUM, + filter_variable="household_net_income", + quantile=10, + quantile_eq=decile_num, + ) + agg.run() + benefits.append(agg.result / 1e9) + + agg = Aggregate( + simulation=simulation, + variable="household_net_income", + aggregate_type=AggregateType.SUM, + filter_variable="household_net_income", + quantile=10, + quantile_eq=decile_num, + ) + agg.run() + net_incomes.append(agg.result / 1e9) + + agg = Aggregate( + simulation=simulation, + variable="household_weight", + aggregate_type=AggregateType.SUM, + filter_variable="household_net_income", + quantile=10, + quantile_eq=decile_num, + ) + agg.run() + counts.append(agg.result / 1e6) + + # Calculate individual benefit programs by decile + benefit_programs_by_decile = {} + + # Person-level benefits (mapped to household for decile filtering) + for prog in [ + "ssi", + "social_security", + "medicaid", + "unemployment_compensation", + ]: + prog_by_decile = [] + for decile_num in range(1, 11): + agg = Aggregate( + simulation=simulation, + variable=prog, + entity="household", + aggregate_type=AggregateType.SUM, + filter_variable="household_net_income", + quantile=10, + quantile_eq=decile_num, + ) + agg.run() + prog_by_decile.append(agg.result / 1e9) + benefit_programs_by_decile[prog] = prog_by_decile + + # SPM unit benefits (mapped to household for decile filtering) + for prog in ["snap", "tanf"]: + prog_by_decile = [] + for decile_num in range(1, 11): + agg = Aggregate( + simulation=simulation, + variable=prog, + entity="household", + aggregate_type=AggregateType.SUM, + filter_variable="household_net_income", + quantile=10, + quantile_eq=decile_num, + ) + agg.run() + prog_by_decile.append(agg.result / 1e9) + benefit_programs_by_decile[prog] = prog_by_decile + + # Tax unit benefits (mapped to household for decile filtering) + for prog in ["eitc", "ctc"]: + prog_by_decile = [] + for decile_num in range(1, 11): + agg = Aggregate( + simulation=simulation, + variable=prog, + entity="household", + aggregate_type=AggregateType.SUM, + filter_variable="household_net_income", + quantile=10, + quantile_eq=decile_num, + ) + agg.run() + 
prog_by_decile.append(agg.result / 1e9) + benefit_programs_by_decile[prog] = prog_by_decile + + return { + "deciles": deciles, + "market_incomes": market_incomes, + "taxes": taxes, + "benefits": benefits, + "net_incomes": net_incomes, + "counts": counts, + "benefit_programs_by_decile": benefit_programs_by_decile, + } + + +def visualise_results(results: dict) -> None: + """Create visualisations of income distribution.""" + # Create overview figure + fig = make_subplots( + rows=2, + cols=2, + subplot_titles=( + "Market income by decile ($bn)", + "Tax by decile ($bn)", + "Benefits by program and decile ($bn)", + "Households by decile (millions)", + ), + specs=[ + [{"type": "bar"}, {"type": "bar"}], + [{"type": "bar"}, {"type": "bar"}], + ], + ) + + # Market income + fig.add_trace( + go.Bar( + x=results["deciles"], + y=results["market_incomes"], + marker_color=COLORS["primary"], + name="Market income", + ), + row=1, + col=1, + ) + + # Tax + fig.add_trace( + go.Bar( + x=results["deciles"], + y=results["taxes"], + marker_color=COLORS["error"], + name="Tax", + ), + row=1, + col=2, + ) + + # Benefits by program (stacked) + benefit_programs = [ + ("Social Security", "social_security"), + ("Medicaid", "medicaid"), + ("SNAP", "snap"), + ("EITC", "eitc"), + ("CTC", "ctc"), + ("SSI", "ssi"), + ("TANF", "tanf"), + ("Unemployment", "unemployment_compensation"), + ] + + for name, key in benefit_programs: + if key in results["benefit_programs_by_decile"]: + fig.add_trace( + go.Bar( + x=results["deciles"], + y=results["benefit_programs_by_decile"][key], + name=name, + ), + row=2, + col=1, + ) + + # Household counts + fig.add_trace( + go.Bar( + x=results["deciles"], + y=results["counts"], + marker_color=COLORS["info"], + name="Households", + ), + row=2, + col=2, + ) + + fig.update_xaxes(title_text="Income decile", row=1, col=1) + fig.update_xaxes(title_text="Income decile", row=1, col=2) + fig.update_xaxes(title_text="Income decile", row=2, col=1) + fig.update_xaxes(title_text="Income decile", row=2, col=2) + + fig.update_layout( + title_text="US household income distribution (Enhanced CPS 2024)", + showlegend=True, + barmode="stack", + height=800, + width=1200, + legend=dict(orientation="h", yanchor="bottom", y=-0.15, xanchor="center", x=0.5), + ) + + fig.show() + + +def main(): + """Main execution function.""" + print("Loading enhanced CPS representative household data...") + dataset = load_representative_data(year=2024) + + print( + f"Dataset loaded: {len(dataset.data.person):,} people, {len(dataset.data.household):,} households" + ) + + print("Running microsimulation...") + simulation = run_simulation(dataset) + + print("Calculating statistics by income decile...") + results = calculate_income_decile_statistics(simulation) + + print("\nResults summary:") + total_market_income = sum(results["market_incomes"]) + total_tax = sum(results["taxes"]) + total_benefits = sum(results["benefits"]) + total_net_income = sum(results["net_incomes"]) + total_households = sum(results["counts"]) + + print(f"Total market income: ${total_market_income:.1f}bn") + print(f"Total tax: ${total_tax:.1f}bn") + print(f"Total benefits: ${total_benefits:.1f}bn") + print(f"Total net income: ${total_net_income:.1f}bn") + print(f"Total households: {total_households:.1f}m") + print( + f"Average effective tax rate: {total_tax / total_market_income * 100:.1f}%" + ) + + print("\nBenefit programs by decile:") + benefit_programs = [ + ("Social Security", "social_security"), + ("Medicaid", "medicaid"), + ("SNAP", "snap"), + ("EITC", "eitc"), + 
("CTC", "ctc"), + ("SSI", "ssi"), + ("TANF", "tanf"), + ("Unemployment", "unemployment_compensation"), + ] + + for name, key in benefit_programs: + if key in results["benefit_programs_by_decile"]: + total = sum(results["benefit_programs_by_decile"][key]) + print(f"\n {name} (total: ${total:.1f}bn):") + for i, decile in enumerate(results["deciles"]): + value = results["benefit_programs_by_decile"][key][i] + print(f" {decile}: ${value:.1f}bn") + + print("\nGenerating visualisations...") + visualise_results(results) + + +if __name__ == "__main__": + main() diff --git a/examples/policy_change.py b/examples/policy_change_uk.py similarity index 100% rename from examples/policy_change.py rename to examples/policy_change_uk.py diff --git a/src/policyengine/tax_benefit_models/us/__init__.py b/src/policyengine/tax_benefit_models/us/__init__.py index b4c0805a..8c273fa0 100644 --- a/src/policyengine/tax_benefit_models/us/__init__.py +++ b/src/policyengine/tax_benefit_models/us/__init__.py @@ -4,7 +4,7 @@ if find_spec("policyengine_us") is not None: from policyengine.core import Dataset - from .datasets import USYearData, PolicyEngineUSDataset + from .datasets import USYearData, PolicyEngineUSDataset, create_datasets from .model import ( PolicyEngineUS, PolicyEngineUSLatest, @@ -23,6 +23,7 @@ __all__ = [ "USYearData", "PolicyEngineUSDataset", + "create_datasets", "PolicyEngineUS", "PolicyEngineUSLatest", "us_model", diff --git a/src/policyengine/tax_benefit_models/us/datasets.py b/src/policyengine/tax_benefit_models/us/datasets.py index 4027feeb..b6f5219d 100644 --- a/src/policyengine/tax_benefit_models/us/datasets.py +++ b/src/policyengine/tax_benefit_models/us/datasets.py @@ -199,3 +199,183 @@ def __repr__(self) -> str: n_tax_units = len(self.data.tax_unit) n_households = len(self.data.household) return f"" + + +def create_datasets( + datasets: list[str] = [ + "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5", + ], + years: list[int] = [2024, 2025, 2026, 2027, 2028], +) -> None: + """Create PolicyEngineUSDataset instances from HuggingFace dataset paths. 
+ + Args: + datasets: List of HuggingFace dataset paths (e.g., "hf://policyengine/policyengine-us-data/cps_2024.h5") + years: List of years to extract data for + """ + from policyengine_us import Microsimulation + + for dataset in datasets: + sim = Microsimulation(dataset=dataset) + + for year in years: + # Get all input variables from the simulation + # We'll calculate each input variable for the specified year + entity_data = { + "person": {}, + "household": {}, + "marital_unit": {}, + "family": {}, + "spm_unit": {}, + "tax_unit": {}, + } + + # First, get ID columns which are structural (not input variables) + # These define entity membership and relationships + # For person-level links to group entities, use person_X_id naming + id_variables = { + "person": [ + "person_id", + "person_household_id", + "person_marital_unit_id", + "person_family_id", + "person_spm_unit_id", + "person_tax_unit_id", + ], + "household": ["household_id"], + "marital_unit": ["marital_unit_id"], + "family": ["family_id"], + "spm_unit": ["spm_unit_id"], + "tax_unit": ["tax_unit_id"], + } + + for entity_key, var_names in id_variables.items(): + for id_var in var_names: + if id_var in sim.tax_benefit_system.variables: + values = sim.calculate(id_var, period=year).values + entity_data[entity_key][id_var] = values + + # Get input variables and calculate them for this year + for variable_name in sim.input_variables: + variable = sim.tax_benefit_system.variables[variable_name] + entity_key = variable.entity.key + + # Calculate the variable for the given year + values = sim.calculate(variable_name, period=year).values + + # Store in the appropriate entity dictionary + entity_data[entity_key][variable_name] = values + + # Build entity DataFrames + person_df = pd.DataFrame(entity_data["person"]) + household_df = pd.DataFrame(entity_data["household"]) + marital_unit_df = pd.DataFrame(entity_data["marital_unit"]) + family_df = pd.DataFrame(entity_data["family"]) + spm_unit_df = pd.DataFrame(entity_data["spm_unit"]) + tax_unit_df = pd.DataFrame(entity_data["tax_unit"]) + + # Add weight columns - household weights are primary, map to all entities + # Person weights = household weights (mapped via person_household_id) + if "household_weight" in household_df.columns: + # Only add person_weight if it doesn't already exist + if "person_weight" not in person_df.columns: + person_df = person_df.merge( + household_df[["household_id", "household_weight"]], + left_on="person_household_id", + right_on="household_id", + how="left", + ) + person_df = person_df.rename( + columns={"household_weight": "person_weight"} + ) + person_df = person_df.drop( + columns=["household_id"], errors="ignore" + ) + + # Map household weights to other group entities via person table + for entity_name, entity_df, person_id_col, entity_id_col in [ + ( + "marital_unit", + marital_unit_df, + "person_marital_unit_id", + "marital_unit_id", + ), + ("family", family_df, "person_family_id", "family_id"), + ( + "spm_unit", + spm_unit_df, + "person_spm_unit_id", + "spm_unit_id", + ), + ( + "tax_unit", + tax_unit_df, + "person_tax_unit_id", + "tax_unit_id", + ), + ]: + # Only add entity weight if it doesn't already exist + if f"{entity_name}_weight" not in entity_df.columns: + # Get household_id for each entity from person table + entity_household_map = person_df[ + [person_id_col, "person_household_id"] + ].drop_duplicates() + entity_df = entity_df.merge( + entity_household_map, + left_on=entity_id_col, + right_on=person_id_col, + how="left", + ) + entity_df = 
entity_df.merge( + household_df[["household_id", "household_weight"]], + left_on="person_household_id", + right_on="household_id", + how="left", + ) + entity_df = entity_df.rename( + columns={ + "household_weight": f"{entity_name}_weight" + } + ) + entity_df = entity_df.drop( + columns=[ + "household_id", + "person_household_id", + person_id_col, + ], + errors="ignore", + ) + + # Update the entity_data + if entity_name == "marital_unit": + marital_unit_df = entity_df + elif entity_name == "family": + family_df = entity_df + elif entity_name == "spm_unit": + spm_unit_df = entity_df + elif entity_name == "tax_unit": + tax_unit_df = entity_df + + us_dataset = PolicyEngineUSDataset( + name=f"{dataset}-year-{year}", + description=f"US Dataset for year {year} based on {dataset}", + filepath=f"./data/{Path(dataset).stem}_year_{year}.h5", + year=year, + data=USYearData( + person=MicroDataFrame(person_df, weights="person_weight"), + household=MicroDataFrame( + household_df, weights="household_weight" + ), + marital_unit=MicroDataFrame( + marital_unit_df, weights="marital_unit_weight" + ), + family=MicroDataFrame(family_df, weights="family_weight"), + spm_unit=MicroDataFrame( + spm_unit_df, weights="spm_unit_weight" + ), + tax_unit=MicroDataFrame( + tax_unit_df, weights="tax_unit_weight" + ), + ), + ) + us_dataset.save() diff --git a/src/policyengine/tax_benefit_models/us/model.py b/src/policyengine/tax_benefit_models/us/model.py index 34a3528d..fda04ad6 100644 --- a/src/policyengine/tax_benefit_models/us/model.py +++ b/src/policyengine/tax_benefit_models/us/model.py @@ -173,22 +173,13 @@ def run(self, simulation: "Simulation") -> "Simulation": "person_weight", # Demographics "age", - "gender", - "is_adult", - "is_child", # Income "employment_income", - "self_employment_income", - "pension_income", - "social_security", - "ssi", # Benefits - "snap", - "tanf", - "medicare", + "ssi", + "social_security", "medicaid", - # Tax - "payroll_tax", + "unemployment_compensation", ], "marital_unit": [ "marital_unit_id", @@ -212,7 +203,6 @@ def run(self, simulation: "Simulation") -> "Simulation": "employee_payroll_tax", "eitc", "ctc", - "adjusted_gross_income", ], "household": [ "household_id", @@ -302,50 +292,78 @@ def _build_simulation_from_dataset(self, microsim, dataset, system): # Extract entity IDs from dataset person_data = pd.DataFrame(dataset.data.person) + # Determine column naming convention + # Support both person_X_id (from create_datasets) and X_id (from custom datasets) + household_id_col = ( + "person_household_id" + if "person_household_id" in person_data.columns + else "household_id" + ) + marital_unit_id_col = ( + "person_marital_unit_id" + if "person_marital_unit_id" in person_data.columns + else "marital_unit_id" + ) + family_id_col = ( + "person_family_id" + if "person_family_id" in person_data.columns + else "family_id" + ) + spm_unit_id_col = ( + "person_spm_unit_id" + if "person_spm_unit_id" in person_data.columns + else "spm_unit_id" + ) + tax_unit_id_col = ( + "person_tax_unit_id" + if "person_tax_unit_id" in person_data.columns + else "tax_unit_id" + ) + # Declare entities builder.declare_person_entity( "person", person_data["person_id"].values ) builder.declare_entity( - "household", np.unique(person_data["household_id"].values) + "household", np.unique(person_data[household_id_col].values) ) builder.declare_entity( - "spm_unit", np.unique(person_data["spm_unit_id"].values) + "spm_unit", np.unique(person_data[spm_unit_id_col].values) ) builder.declare_entity( - "family", 
np.unique(person_data["family_id"].values) + "family", np.unique(person_data[family_id_col].values) ) builder.declare_entity( - "tax_unit", np.unique(person_data["tax_unit_id"].values) + "tax_unit", np.unique(person_data[tax_unit_id_col].values) ) builder.declare_entity( - "marital_unit", np.unique(person_data["marital_unit_id"].values) + "marital_unit", np.unique(person_data[marital_unit_id_col].values) ) # Join persons to group entities builder.join_with_persons( builder.populations["household"], - person_data["household_id"].values, + person_data[household_id_col].values, np.array(["member"] * len(person_data)), ) builder.join_with_persons( builder.populations["spm_unit"], - person_data["spm_unit_id"].values, + person_data[spm_unit_id_col].values, np.array(["member"] * len(person_data)), ) builder.join_with_persons( builder.populations["family"], - person_data["family_id"].values, + person_data[family_id_col].values, np.array(["member"] * len(person_data)), ) builder.join_with_persons( builder.populations["tax_unit"], - person_data["tax_unit_id"].values, + person_data[tax_unit_id_col].values, np.array(["member"] * len(person_data)), ) builder.join_with_persons( builder.populations["marital_unit"], - person_data["marital_unit_id"].values, + person_data[marital_unit_id_col].values, np.array(["member"] * len(person_data)), ) @@ -354,13 +372,19 @@ def _build_simulation_from_dataset(self, microsim, dataset, system): # Set input variables for each entity # Skip ID columns as they're structural and already used in entity building + # Support both naming conventions id_columns = { "person_id", "household_id", + "person_household_id", "spm_unit_id", + "person_spm_unit_id", "family_id", + "person_family_id", "tax_unit_id", + "person_tax_unit_id", "marital_unit_id", + "person_marital_unit_id", } for entity_name, entity_df in [ diff --git a/tests/test_us_datasets.py b/tests/test_us_datasets.py new file mode 100644 index 00000000..6f84c507 --- /dev/null +++ b/tests/test_us_datasets.py @@ -0,0 +1,108 @@ +"""Tests for US dataset creation from HuggingFace paths.""" + +import pytest +import pandas as pd +from pathlib import Path +import shutil +from policyengine.tax_benefit_models.us import ( + create_datasets, + PolicyEngineUSDataset, +) + + +def test_create_datasets_from_enhanced_cps(): + """Test creating datasets from enhanced CPS HuggingFace path.""" + # Clean up data directory if it exists + data_dir = Path("./data") + if data_dir.exists(): + shutil.rmtree(data_dir) + + # Create datasets for a single year to test + datasets = ["hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5"] + years = [2024] + + create_datasets(datasets=datasets, years=years) + + # Verify the dataset was created + dataset_file = data_dir / "enhanced_cps_2024_year_2024.h5" + assert dataset_file.exists(), f"Dataset file {dataset_file} should exist" + + # Load and verify dataset structure + dataset = PolicyEngineUSDataset( + name="test", + description="test", + filepath=str(dataset_file), + year=2024, + ) + dataset.load() + + # Check all entity types exist + assert dataset.data is not None + assert dataset.data.person is not None + assert dataset.data.household is not None + assert dataset.data.marital_unit is not None + assert dataset.data.family is not None + assert dataset.data.spm_unit is not None + assert dataset.data.tax_unit is not None + + # Check person data has required columns + person_df = pd.DataFrame(dataset.data.person) + assert "person_id" in person_df.columns + assert "person_household_id" in 
person_df.columns + assert "person_weight" in person_df.columns + assert len(person_df) > 0 + + # Check household data + household_df = pd.DataFrame(dataset.data.household) + assert "household_id" in household_df.columns + assert "household_weight" in household_df.columns + assert len(household_df) > 0 + + # Check all group entities have weight columns + for entity_name in [ + "marital_unit", + "family", + "spm_unit", + "tax_unit", + ]: + entity_df = pd.DataFrame(getattr(dataset.data, entity_name)) + assert f"{entity_name}_id" in entity_df.columns + assert f"{entity_name}_weight" in entity_df.columns + assert len(entity_df) > 0 + + # Clean up + shutil.rmtree(data_dir) + + +def test_create_datasets_multiple_years(): + """Test creating datasets for multiple years.""" + # Clean up data directory if it exists + data_dir = Path("./data") + if data_dir.exists(): + shutil.rmtree(data_dir) + + datasets = ["hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5"] + years = [2024, 2025] + + create_datasets(datasets=datasets, years=years) + + # Verify both year datasets were created + for year in years: + dataset_file = data_dir / f"enhanced_cps_2024_year_{year}.h5" + assert dataset_file.exists(), ( + f"Dataset file for year {year} should exist" + ) + + # Load and verify + dataset = PolicyEngineUSDataset( + name=f"test-{year}", + description=f"test {year}", + filepath=str(dataset_file), + year=year, + ) + dataset.load() + assert dataset.data is not None + assert len(pd.DataFrame(dataset.data.person)) > 0 + + # Clean up + shutil.rmtree(data_dir) From be1646bd7f3e0e58d3bbff652fbf08ba582f1111 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Sun, 16 Nov 2025 23:08:46 +0000 Subject: [PATCH 27/35] Standardise --- examples/income_distribution_us.py | 46 +++ examples/speedtest_us_simulation.py | 291 ++++++++++++++++++ src/policyengine/outputs/aggregate.py | 15 +- .../tax_benefit_models/us/datasets.py | 82 ++++- 4 files changed, 413 insertions(+), 21 deletions(-) create mode 100644 examples/speedtest_us_simulation.py diff --git a/examples/income_distribution_us.py b/examples/income_distribution_us.py index 02173845..20803ef7 100644 --- a/examples/income_distribution_us.py +++ b/examples/income_distribution_us.py @@ -10,6 +10,7 @@ """ from pathlib import Path +import time import plotly.graph_objects as go from plotly.subplots import make_subplots from policyengine.core import Simulation @@ -55,6 +56,7 @@ def run_simulation(dataset: PolicyEngineUSDataset) -> Simulation: def calculate_income_decile_statistics(simulation: Simulation) -> dict: """Calculate total income, tax, and benefits by income deciles.""" + start_time = time.time() deciles = [f"D{i}" for i in range(1, 11)] market_incomes = [] taxes = [] @@ -63,7 +65,13 @@ def calculate_income_decile_statistics(simulation: Simulation) -> dict: counts = [] # Calculate household-level aggregates by decile + print("Calculating main statistics by decile...") + main_stats_start = time.time() for decile_num in range(1, 11): + decile_start = time.time() + + # Market income + pre_create = time.time() agg = Aggregate( simulation=simulation, variable="household_market_income", @@ -72,7 +80,12 @@ def calculate_income_decile_statistics(simulation: Simulation) -> dict: quantile=10, quantile_eq=decile_num, ) + if decile_num == 1: + print(f" First Aggregate created ({time.time() - pre_create:.2f}s)") + pre_run = time.time() agg.run() + if decile_num == 1: + print(f" First Aggregate.run() complete ({time.time() - pre_run:.2f}s)") market_incomes.append(agg.result / 
1e9) agg = Aggregate( @@ -119,18 +132,28 @@ def calculate_income_decile_statistics(simulation: Simulation) -> dict: agg.run() counts.append(agg.result / 1e6) + print(f" D{decile_num} complete ({time.time() - decile_start:.2f}s)") + + print(f"Main statistics complete ({time.time() - main_stats_start:.2f}s)") + # Calculate individual benefit programs by decile benefit_programs_by_decile = {} # Person-level benefits (mapped to household for decile filtering) + print("Calculating person-level benefit programs...") + person_benefits_start = time.time() + first_prog = True for prog in [ "ssi", "social_security", "medicaid", "unemployment_compensation", ]: + prog_start = time.time() prog_by_decile = [] for decile_num in range(1, 11): + if first_prog and decile_num == 1: + pre_create = time.time() agg = Aggregate( simulation=simulation, variable=prog, @@ -139,13 +162,26 @@ def calculate_income_decile_statistics(simulation: Simulation) -> dict: filter_variable="household_net_income", quantile=10, quantile_eq=decile_num, + debug_timing=first_prog and decile_num == 1, ) + if first_prog and decile_num == 1: + print(f" First benefit Aggregate created ({time.time() - pre_create:.2f}s)") + pre_run = time.time() agg.run() + if first_prog and decile_num == 1: + print(f" First benefit Aggregate.run() complete ({time.time() - pre_run:.2f}s)") + first_prog = False prog_by_decile.append(agg.result / 1e9) benefit_programs_by_decile[prog] = prog_by_decile + print(f" {prog} complete ({time.time() - prog_start:.2f}s)") + + print(f"Person-level benefits complete ({time.time() - person_benefits_start:.2f}s)") # SPM unit benefits (mapped to household for decile filtering) + print("Calculating SPM unit benefit programs...") + spm_benefits_start = time.time() for prog in ["snap", "tanf"]: + prog_start = time.time() prog_by_decile = [] for decile_num in range(1, 11): agg = Aggregate( @@ -160,9 +196,15 @@ def calculate_income_decile_statistics(simulation: Simulation) -> dict: agg.run() prog_by_decile.append(agg.result / 1e9) benefit_programs_by_decile[prog] = prog_by_decile + print(f" {prog} complete ({time.time() - prog_start:.2f}s)") + + print(f"SPM benefits complete ({time.time() - spm_benefits_start:.2f}s)") # Tax unit benefits (mapped to household for decile filtering) + print("Calculating tax unit benefit programs...") + tax_benefits_start = time.time() for prog in ["eitc", "ctc"]: + prog_start = time.time() prog_by_decile = [] for decile_num in range(1, 11): agg = Aggregate( @@ -177,6 +219,10 @@ def calculate_income_decile_statistics(simulation: Simulation) -> dict: agg.run() prog_by_decile.append(agg.result / 1e9) benefit_programs_by_decile[prog] = prog_by_decile + print(f" {prog} complete ({time.time() - prog_start:.2f}s)") + + print(f"Tax benefits complete ({time.time() - tax_benefits_start:.2f}s)") + print(f"\nTotal statistics calculation time: {time.time() - start_time:.2f}s") return { "deciles": deciles, diff --git a/examples/speedtest_us_simulation.py b/examples/speedtest_us_simulation.py new file mode 100644 index 00000000..755a17f9 --- /dev/null +++ b/examples/speedtest_us_simulation.py @@ -0,0 +1,291 @@ +"""Speedtest: US simulation performance with different dataset sizes. + +This script tests how simulation.run() performance scales with dataset size +by running simulations on random subsets of households. 
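+
+Run: python examples/speedtest_us_simulation.py
+(assumes examples/data/enhanced_cps_2024_year_2024.h5 exists; generate it
+with create_datasets() from policyengine.tax_benefit_models.us)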
+""" + +from pathlib import Path +import time +import pandas as pd +from policyengine.core import Simulation +from policyengine.tax_benefit_models.us import ( + PolicyEngineUSDataset, + USYearData, + us_latest, +) +from microdf import MicroDataFrame + + +def create_subset_dataset( + original_dataset: PolicyEngineUSDataset, n_households: int +) -> PolicyEngineUSDataset: + """Create a random subset of the dataset with n_households and reindexed entity IDs.""" + # Get original data + household_df = pd.DataFrame(original_dataset.data.household).copy() + person_df = pd.DataFrame(original_dataset.data.person).copy() + marital_unit_df = pd.DataFrame(original_dataset.data.marital_unit).copy() + family_df = pd.DataFrame(original_dataset.data.family).copy() + spm_unit_df = pd.DataFrame(original_dataset.data.spm_unit).copy() + tax_unit_df = pd.DataFrame(original_dataset.data.tax_unit).copy() + + # Sample random households (use n as seed to get different samples for different sizes) + sampled_households = household_df.sample(n=n_households, random_state=n_households).copy() + sampled_household_ids = set(sampled_households["household_id"]) + + # Determine column naming convention + household_id_col = ( + "person_household_id" + if "person_household_id" in person_df.columns + else "household_id" + ) + marital_unit_id_col = ( + "person_marital_unit_id" + if "person_marital_unit_id" in person_df.columns + else "marital_unit_id" + ) + family_id_col = ( + "person_family_id" + if "person_family_id" in person_df.columns + else "family_id" + ) + spm_unit_id_col = ( + "person_spm_unit_id" + if "person_spm_unit_id" in person_df.columns + else "spm_unit_id" + ) + tax_unit_id_col = ( + "person_tax_unit_id" + if "person_tax_unit_id" in person_df.columns + else "tax_unit_id" + ) + + # Filter person table to only include people in sampled households + sampled_person = person_df[ + person_df[household_id_col].isin(sampled_household_ids) + ].copy() + + # Get IDs of group entities that have members in sampled households + sampled_marital_unit_ids = set(sampled_person[marital_unit_id_col].unique()) + sampled_family_ids = set(sampled_person[family_id_col].unique()) + sampled_spm_unit_ids = set(sampled_person[spm_unit_id_col].unique()) + sampled_tax_unit_ids = set(sampled_person[tax_unit_id_col].unique()) + + # Filter group entity tables + sampled_marital_unit = marital_unit_df[ + marital_unit_df["marital_unit_id"].isin(sampled_marital_unit_ids) + ].copy() + sampled_family = family_df[ + family_df["family_id"].isin(sampled_family_ids) + ].copy() + sampled_spm_unit = spm_unit_df[ + spm_unit_df["spm_unit_id"].isin(sampled_spm_unit_ids) + ].copy() + sampled_tax_unit = tax_unit_df[ + tax_unit_df["tax_unit_id"].isin(sampled_tax_unit_ids) + ].copy() + + # Create ID mappings to reindex to contiguous integers starting from 0 + household_id_map = { + old_id: new_id + for new_id, old_id in enumerate(sorted(sampled_household_ids)) + } + marital_unit_id_map = { + old_id: new_id + for new_id, old_id in enumerate(sorted(sampled_marital_unit_ids)) + } + family_id_map = { + old_id: new_id for new_id, old_id in enumerate(sorted(sampled_family_ids)) + } + spm_unit_id_map = { + old_id: new_id + for new_id, old_id in enumerate(sorted(sampled_spm_unit_ids)) + } + tax_unit_id_map = { + old_id: new_id + for new_id, old_id in enumerate(sorted(sampled_tax_unit_ids)) + } + person_id_map = { + old_id: new_id + for new_id, old_id in enumerate(sorted(sampled_person["person_id"])) + } + + # Reindex all entity IDs in household table + 
sampled_households["household_id"] = sampled_households["household_id"].map( + household_id_map + ) + + # Reindex all entity IDs in person table + sampled_person["person_id"] = sampled_person["person_id"].map(person_id_map) + sampled_person[household_id_col] = sampled_person[household_id_col].map( + household_id_map + ) + sampled_person[marital_unit_id_col] = sampled_person[marital_unit_id_col].map( + marital_unit_id_map + ) + sampled_person[family_id_col] = sampled_person[family_id_col].map(family_id_map) + sampled_person[spm_unit_id_col] = sampled_person[spm_unit_id_col].map( + spm_unit_id_map + ) + sampled_person[tax_unit_id_col] = sampled_person[tax_unit_id_col].map( + tax_unit_id_map + ) + + # Reindex group entity tables + sampled_marital_unit["marital_unit_id"] = sampled_marital_unit[ + "marital_unit_id" + ].map(marital_unit_id_map) + sampled_family["family_id"] = sampled_family["family_id"].map(family_id_map) + sampled_spm_unit["spm_unit_id"] = sampled_spm_unit["spm_unit_id"].map( + spm_unit_id_map + ) + sampled_tax_unit["tax_unit_id"] = sampled_tax_unit["tax_unit_id"].map( + tax_unit_id_map + ) + + # Sort by ID to ensure proper ordering + sampled_households = sampled_households.sort_values("household_id").reset_index( + drop=True + ) + sampled_person = sampled_person.sort_values("person_id").reset_index(drop=True) + sampled_marital_unit = sampled_marital_unit.sort_values( + "marital_unit_id" + ).reset_index(drop=True) + sampled_family = sampled_family.sort_values("family_id").reset_index(drop=True) + sampled_spm_unit = sampled_spm_unit.sort_values("spm_unit_id").reset_index( + drop=True + ) + sampled_tax_unit = sampled_tax_unit.sort_values("tax_unit_id").reset_index( + drop=True + ) + + # Create new dataset + subset_dataset = PolicyEngineUSDataset( + name=f"Subset {n_households} households", + description=f"Random subset of {n_households} households", + filepath=f"./data/subset_{n_households}_households.h5", + year=original_dataset.year, + data=USYearData( + person=MicroDataFrame(sampled_person, weights="person_weight"), + household=MicroDataFrame( + sampled_households, weights="household_weight" + ), + marital_unit=MicroDataFrame( + sampled_marital_unit, weights="marital_unit_weight" + ), + family=MicroDataFrame(sampled_family, weights="family_weight"), + spm_unit=MicroDataFrame( + sampled_spm_unit, weights="spm_unit_weight" + ), + tax_unit=MicroDataFrame(sampled_tax_unit, weights="tax_unit_weight"), + ), + ) + + return subset_dataset + + +def speedtest_simulation(dataset: PolicyEngineUSDataset) -> float: + """Run simulation and return execution time in seconds.""" + simulation = Simulation( + dataset=dataset, + tax_benefit_model_version=us_latest, + ) + + start_time = time.time() + simulation.run() + end_time = time.time() + + return end_time - start_time + + +def main(): + print("Loading full enhanced CPS dataset...") + dataset_path = Path(__file__).parent / "data" / "enhanced_cps_2024_year_2024.h5" + + if not dataset_path.exists(): + raise FileNotFoundError( + f"Dataset not found at {dataset_path}. " + "Run create_datasets() from policyengine.tax_benefit_models.us first." 
+ ) + + full_dataset = PolicyEngineUSDataset( + name="Enhanced CPS 2024", + description="Full enhanced CPS dataset", + filepath=str(dataset_path), + year=2024, + ) + full_dataset.load() + + total_households = len(full_dataset.data.household) + print(f"Full dataset: {total_households:,} households") + + # Test different subset sizes + test_sizes = [100, 500, 1000, 2500, 5000, 10000, 21532] # Last is full size + + results = [] + + for n_households in test_sizes: + if n_households > total_households: + continue + + print(f"\nTesting {n_households:,} households...") + + if n_households == total_households: + subset = full_dataset + else: + subset = create_subset_dataset(full_dataset, n_households) + + n_people = len(subset.data.person) + print(f" {n_people:,} people in subset") + + duration = speedtest_simulation(subset) + print(f" Simulation completed in {duration:.2f}s") + + results.append( + { + "households": n_households, + "people": n_people, + "duration_seconds": duration, + "households_per_second": n_households / duration, + } + ) + + print("\n" + "=" * 60) + print("SPEEDTEST RESULTS") + print("=" * 60) + print( + f"{'Households':<12} {'People':<10} {'Duration':<12} {'HH/sec':<10}" + ) + print("-" * 60) + + for result in results: + print( + f"{result['households']:<12,} {result['people']:<10,} " + f"{result['duration_seconds']:<12.2f} {result['households_per_second']:<10.1f}" + ) + + # Calculate scaling characteristics + print("\n" + "=" * 60) + print("SCALING ANALYSIS") + print("=" * 60) + + if len(results) >= 2: + # Compare first and last results + first = results[0] + last = results[-1] + + size_ratio = last["households"] / first["households"] + time_ratio = last["duration_seconds"] / first["duration_seconds"] + + print(f"Dataset size increased {size_ratio:.1f}x") + print(f"Simulation time increased {time_ratio:.1f}x") + + if time_ratio < size_ratio * 1.2: + print("Scaling: approximately linear or better") + elif time_ratio < size_ratio * 2: + print("Scaling: slightly worse than linear") + else: + print("Scaling: significantly worse than linear") + + +if __name__ == "__main__": + main() diff --git a/src/policyengine/outputs/aggregate.py b/src/policyengine/outputs/aggregate.py index 29f18138..2619c9e4 100644 --- a/src/policyengine/outputs/aggregate.py +++ b/src/policyengine/outputs/aggregate.py @@ -34,7 +34,6 @@ class Aggregate(Output): ) result: Any | None = None - def run(self): # Convert quantile specification to describes_quantiles format if self.quantile is not None: @@ -66,7 +65,7 @@ def run(self): # Map variable to target entity if needed if var_obj.entity != target_entity: mapped = self.simulation.output_dataset.data.map_to_entity( - var_obj.entity, target_entity + var_obj.entity, target_entity, columns=[self.variable] ) series = mapped[self.variable] else: @@ -83,7 +82,7 @@ def run(self): if filter_var_obj.entity != target_entity: filter_mapped = ( self.simulation.output_dataset.data.map_to_entity( - filter_var_obj.entity, target_entity + filter_var_obj.entity, target_entity, columns=[self.filter_variable] ) ) filter_series = filter_mapped[self.filter_variable] @@ -95,14 +94,10 @@ def run(self): threshold = filter_series.quantile(self.filter_variable_eq) series = series[filter_series <= threshold] if self.filter_variable_leq is not None: - threshold = filter_series.quantile( - self.filter_variable_leq - ) + threshold = filter_series.quantile(self.filter_variable_leq) series = series[filter_series <= threshold] if self.filter_variable_geq is not None: - threshold = 
filter_series.quantile( - self.filter_variable_geq - ) + threshold = filter_series.quantile(self.filter_variable_geq) series = series[filter_series >= threshold] else: if self.filter_variable_eq is not None: @@ -112,7 +107,7 @@ def run(self): if self.filter_variable_geq is not None: series = series[filter_series >= self.filter_variable_geq] - # Aggregate + # Aggregate - MicroSeries will automatically apply weights if self.aggregate_type == AggregateType.SUM: self.result = series.sum() elif self.aggregate_type == AggregateType.MEAN: diff --git a/src/policyengine/tax_benefit_models/us/datasets.py b/src/policyengine/tax_benefit_models/us/datasets.py index b6f5219d..adf08f8b 100644 --- a/src/policyengine/tax_benefit_models/us/datasets.py +++ b/src/policyengine/tax_benefit_models/us/datasets.py @@ -50,8 +50,11 @@ def map_to_entity( f"Invalid target entity '{target_entity}'. Must be one of {valid_entities}" ) - # Get source data + # Get source data (as raw pandas DataFrame, not MicroDataFrame) source_df = getattr(self, source_entity) + # Convert MicroDataFrame to plain pandas DataFrame to avoid weighted values + source_df = pd.DataFrame(source_df) + if columns: # Select only requested columns (keep join keys) join_keys = { @@ -90,19 +93,51 @@ def map_to_entity( # Person to group entity: aggregate person-level data to group level if source_entity == "person" and target_entity != "person": - if target_key in pd.DataFrame(source_df).columns: - # Merge source (person) with target (group) on target_key - result = pd.DataFrame(target_df).merge( - pd.DataFrame(source_df), on=target_key, how="left" + if target_key in source_df.columns: + # Aggregate person-level data to household level first + source_pd = source_df + # Get columns to aggregate (exclude join keys and weights) + join_keys = { + "person_id", + "marital_unit_id", + "family_id", + "spm_unit_id", + "tax_unit_id", + "household_id", + } + # Also exclude weight columns + weight_cols = { + "person_weight", + "marital_unit_weight", + "family_weight", + "spm_unit_weight", + "tax_unit_weight", + "household_weight", + } + agg_cols = [c for c in source_pd.columns if c not in join_keys and c not in weight_cols] + # Group by target key and sum + aggregated = source_pd.groupby(target_key, as_index=False)[agg_cols].sum() + # Merge with target, preserving original order + target_pd = pd.DataFrame(target_df)[[target_key, target_weight]] + # Add index to preserve order + target_pd = target_pd.reset_index(drop=False) + result = target_pd.merge( + aggregated, on=target_key, how="left" ) + # Sort back to original order + result = result.sort_values('index').drop('index', axis=1).reset_index(drop=True) + # Fill NaN with 0 for households with no members in source entity + result[agg_cols] = result[agg_cols].fillna(0) + # Return MicroDataFrame with proper weights return MicroDataFrame(result, weights=target_weight) # Group entity to person: expand group-level data to person level if source_entity != "person" and target_entity == "person": source_key = f"{source_entity}_id" - if source_key in pd.DataFrame(target_df).columns: - result = pd.DataFrame(target_df).merge( - pd.DataFrame(source_df), on=source_key, how="left" + target_pd = pd.DataFrame(target_df) + if source_key in target_pd.columns: + result = target_pd.merge( + source_df, on=source_key, how="left" ) return MicroDataFrame(result, weights=target_weight) @@ -120,12 +155,37 @@ def map_to_entity( person_link = person_df[ [source_key, target_key] ].drop_duplicates() - source_with_target = 
pd.DataFrame(source_df).merge( + # Join source data with target key + source_with_target = source_df.merge( person_link, on=source_key, how="left" ) - result = pd.DataFrame(target_df).merge( - source_with_target, on=target_key, how="left" + # Aggregate to target level + join_keys_all = { + "person_id", "marital_unit_id", "family_id", + "spm_unit_id", "tax_unit_id", "household_id", + } + weight_cols = { + "person_weight", + "marital_unit_weight", + "family_weight", + "spm_unit_weight", + "tax_unit_weight", + "household_weight", + } + agg_cols = [c for c in source_with_target.columns if c not in join_keys_all and c not in weight_cols] + aggregated = source_with_target.groupby(target_key, as_index=False)[agg_cols].sum() + # Merge with target, preserving original order + target_pd = pd.DataFrame(target_df)[[target_key, target_weight]] + # Add index to preserve order + target_pd = target_pd.reset_index(drop=False) + result = target_pd.merge( + aggregated, on=target_key, how="left" ) + # Sort back to original order + result = result.sort_values('index').drop('index', axis=1).reset_index(drop=True) + # Fill NaN with 0 + result[agg_cols] = result[agg_cols].fillna(0) + # Return MicroDataFrame with proper weights return MicroDataFrame(result, weights=target_weight) raise ValueError( From d569d585e667a23edefc51a2bb1c790636ca4afc Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Sun, 16 Nov 2025 23:37:57 +0000 Subject: [PATCH 28/35] Update pkg --- examples/income_distribution_us.py | 75 ++-- src/policyengine/core/__init__.py | 2 +- src/policyengine/core/dataset.py | 225 ++++++++++++ .../tax_benefit_models/uk/datasets.py | 109 +----- .../tax_benefit_models/us/datasets.py | 171 +-------- .../tax_benefit_models/us/model.py | 48 ++- tests/test_entity_mapping.py | 93 +++-- tests/test_us_entity_mapping.py | 333 ++++++++++++++++++ 8 files changed, 748 insertions(+), 308 deletions(-) create mode 100644 tests/test_us_entity_mapping.py diff --git a/examples/income_distribution_us.py b/examples/income_distribution_us.py index 20803ef7..fdbf48bf 100644 --- a/examples/income_distribution_us.py +++ b/examples/income_distribution_us.py @@ -81,11 +81,15 @@ def calculate_income_decile_statistics(simulation: Simulation) -> dict: quantile_eq=decile_num, ) if decile_num == 1: - print(f" First Aggregate created ({time.time() - pre_create:.2f}s)") + print( + f" First Aggregate created ({time.time() - pre_create:.2f}s)" + ) pre_run = time.time() agg.run() if decile_num == 1: - print(f" First Aggregate.run() complete ({time.time() - pre_run:.2f}s)") + print( + f" First Aggregate.run() complete ({time.time() - pre_run:.2f}s)" + ) market_incomes.append(agg.result / 1e9) agg = Aggregate( @@ -165,17 +169,23 @@ def calculate_income_decile_statistics(simulation: Simulation) -> dict: debug_timing=first_prog and decile_num == 1, ) if first_prog and decile_num == 1: - print(f" First benefit Aggregate created ({time.time() - pre_create:.2f}s)") + print( + f" First benefit Aggregate created ({time.time() - pre_create:.2f}s)" + ) pre_run = time.time() agg.run() if first_prog and decile_num == 1: - print(f" First benefit Aggregate.run() complete ({time.time() - pre_run:.2f}s)") + print( + f" First benefit Aggregate.run() complete ({time.time() - pre_run:.2f}s)" + ) first_prog = False prog_by_decile.append(agg.result / 1e9) benefit_programs_by_decile[prog] = prog_by_decile print(f" {prog} complete ({time.time() - prog_start:.2f}s)") - print(f"Person-level benefits complete ({time.time() - person_benefits_start:.2f}s)") + print( + 
f"Person-level benefits complete ({time.time() - person_benefits_start:.2f}s)" + ) # SPM unit benefits (mapped to household for decile filtering) print("Calculating SPM unit benefit programs...") @@ -222,7 +232,9 @@ def calculate_income_decile_statistics(simulation: Simulation) -> dict: print(f" {prog} complete ({time.time() - prog_start:.2f}s)") print(f"Tax benefits complete ({time.time() - tax_benefits_start:.2f}s)") - print(f"\nTotal statistics calculation time: {time.time() - start_time:.2f}s") + print( + f"\nTotal statistics calculation time: {time.time() - start_time:.2f}s" + ) return { "deciles": deciles, @@ -260,6 +272,7 @@ def visualise_results(results: dict) -> None: y=results["market_incomes"], marker_color=COLORS["primary"], name="Market income", + showlegend=False, ), row=1, col=1, @@ -272,30 +285,34 @@ def visualise_results(results: dict) -> None: y=results["taxes"], marker_color=COLORS["error"], name="Tax", + showlegend=False, ), row=1, col=2, ) - # Benefits by program (stacked) + # Benefits by program (stacked) - with legend benefit_programs = [ - ("Social Security", "social_security"), - ("Medicaid", "medicaid"), - ("SNAP", "snap"), - ("EITC", "eitc"), - ("CTC", "ctc"), - ("SSI", "ssi"), - ("TANF", "tanf"), - ("Unemployment", "unemployment_compensation"), + ("Social Security", "social_security", "#026AA2"), + ("Medicaid", "medicaid", "#319795"), + ("SNAP", "snap", "#22C55E"), + ("EITC", "eitc", "#FEC601"), + ("CTC", "ctc", "#1890FF"), + ("SSI", "ssi", "#EF4444"), + ("TANF", "tanf", "#667085"), + ("Unemployment", "unemployment_compensation", "#101828"), ] - for name, key in benefit_programs: + for name, key, color in benefit_programs: if key in results["benefit_programs_by_decile"]: fig.add_trace( go.Bar( x=results["deciles"], y=results["benefit_programs_by_decile"][key], name=name, + marker_color=color, + legendgroup="benefits", + showlegend=True, ), row=2, col=1, @@ -308,6 +325,7 @@ def visualise_results(results: dict) -> None: y=results["counts"], marker_color=COLORS["info"], name="Households", + showlegend=False, ), row=2, col=2, @@ -318,13 +336,28 @@ def visualise_results(results: dict) -> None: fig.update_xaxes(title_text="Income decile", row=2, col=1) fig.update_xaxes(title_text="Income decile", row=2, col=2) + # Apply PolicyEngine formatting + format_fig( + fig, + title="US household income distribution (Enhanced CPS 2024)", + show_legend=True, + height=800, + width=1400, + ) + + # Override legend position for subplot layout fig.update_layout( - title_text="US household income distribution (Enhanced CPS 2024)", - showlegend=True, barmode="stack", - height=800, - width=1200, - legend=dict(orientation="h", yanchor="bottom", y=-0.15, xanchor="center", x=0.5), + legend=dict( + orientation="v", + yanchor="top", + y=0.45, + xanchor="left", + x=0.52, + bgcolor="white", + bordercolor="#E5E7EB", + borderwidth=1, + ), ) fig.show() diff --git a/src/policyengine/core/__init__.py b/src/policyengine/core/__init__.py index 58372d57..5ecdcafb 100644 --- a/src/policyengine/core/__init__.py +++ b/src/policyengine/core/__init__.py @@ -1,5 +1,5 @@ from .variable import Variable -from .dataset import Dataset +from .dataset import Dataset, map_to_entity from .dynamic import Dynamic from .tax_benefit_model import TaxBenefitModel from .tax_benefit_model_version import TaxBenefitModelVersion diff --git a/src/policyengine/core/dataset.py b/src/policyengine/core/dataset.py index 5ae76fa4..d73ee71d 100644 --- a/src/policyengine/core/dataset.py +++ b/src/policyengine/core/dataset.py @@ -1,6 
+1,8 @@ from uuid import uuid4 from pydantic import BaseModel, Field, ConfigDict +import pandas as pd +from microdf import MicroDataFrame from .tax_benefit_model import TaxBenefitModel from .dataset_version import DatasetVersion @@ -33,3 +35,226 @@ class MyDataset(Dataset): year: int data: BaseModel | None = None + + +def map_to_entity( + entity_data: dict[str, MicroDataFrame], + source_entity: str, + target_entity: str, + person_entity: str = "person", + columns: list[str] | None = None, +) -> MicroDataFrame: + """Map data from source entity to target entity using join keys. + + This is a generic entity mapping utility that handles: + - Same entity mapping (returns as is) + - Person to group entity mapping (aggregates values) + - Group to person entity mapping (expands values) + - Group to group entity mapping (aggregates through person entity) + + Args: + entity_data: Dictionary mapping entity names to their MicroDataFrame data + source_entity: The source entity name + target_entity: The target entity name + person_entity: The name of the person entity (default "person") + columns: List of column names to map. If None, maps all columns + + Returns: + MicroDataFrame: The mapped data at the target entity level + + Raises: + ValueError: If source or target entity is invalid + """ + valid_entities = set(entity_data.keys()) + + if source_entity not in valid_entities: + raise ValueError( + f"Invalid source entity '{source_entity}'. Must be one of {valid_entities}" + ) + if target_entity not in valid_entities: + raise ValueError( + f"Invalid target entity '{target_entity}'. Must be one of {valid_entities}" + ) + + # Get source data (convert to plain DataFrame to avoid weighted operations during mapping) + source_df = pd.DataFrame(entity_data[source_entity]) + + if columns: + # Select only requested columns (keep all ID columns for joins) + id_cols = {col for col in source_df.columns if col.endswith("_id")} + cols_to_keep = list(set(columns) | id_cols) + source_df = source_df[cols_to_keep] + + # Determine weight column for target entity + target_weight = f"{target_entity}_weight" + + # Same entity - return as is + if source_entity == target_entity: + return MicroDataFrame(source_df, weights=target_weight) + + # Get target data and key + target_df = entity_data[target_entity] + target_key = f"{target_entity}_id" + + # Person to group entity: aggregate person-level data to group level + if source_entity == person_entity and target_entity != person_entity: + # Check for both naming patterns: "entity_id" and "person_entity_id" + person_target_key = f"{person_entity}_{target_entity}_id" + join_key = ( + person_target_key + if person_target_key in source_df.columns + else target_key + ) + + if join_key in source_df.columns: + # Get columns to aggregate (exclude ID and weight columns) + id_cols = {col for col in source_df.columns if col.endswith("_id")} + weight_cols = { + col for col in source_df.columns if col.endswith("_weight") + } + agg_cols = [ + c + for c in source_df.columns + if c not in id_cols and c not in weight_cols + ] + + # Group by join key and sum + aggregated = source_df.groupby(join_key, as_index=False)[ + agg_cols + ].sum() + + # Rename join key to target key if needed + if join_key != target_key: + aggregated = aggregated.rename(columns={join_key: target_key}) + + # Merge with target, preserving original order + target_pd = pd.DataFrame(target_df)[[target_key, target_weight]] + target_pd = target_pd.reset_index(drop=False) + result = target_pd.merge(aggregated, on=target_key, 
how="left") + + # Sort back to original order + result = ( + result.sort_values("index") + .drop("index", axis=1) + .reset_index(drop=True) + ) + + # Fill NaN with 0 for groups with no members in source entity + result[agg_cols] = result[agg_cols].fillna(0) + + return MicroDataFrame(result, weights=target_weight) + + # Group entity to person: expand group-level data to person level + if source_entity != person_entity and target_entity == person_entity: + source_key = f"{source_entity}_id" + # Check for both naming patterns + person_source_key = f"{person_entity}_{source_entity}_id" + + target_pd = pd.DataFrame(target_df) + join_key = ( + person_source_key + if person_source_key in target_pd.columns + else source_key + ) + + if join_key in target_pd.columns: + # Rename source key to match join key if needed + if join_key != source_key and source_key in source_df.columns: + source_df = source_df.rename(columns={source_key: join_key}) + + result = target_pd.merge(source_df, on=join_key, how="left") + return MicroDataFrame(result, weights=target_weight) + + # Group to group: go through person table + if source_entity != person_entity and target_entity != person_entity: + # Get person link table with both entity IDs + person_df = pd.DataFrame(entity_data[person_entity]) + source_key = f"{source_entity}_id" + + # Check for both naming patterns for person-level links + person_source_key = f"{person_entity}_{source_entity}_id" + person_target_key = f"{person_entity}_{target_entity}_id" + + # Determine which keys exist in person table + source_link_key = ( + person_source_key + if person_source_key in person_df.columns + else source_key + ) + target_link_key = ( + person_target_key + if person_target_key in person_df.columns + else target_key + ) + + # Link source -> person -> target + if ( + source_link_key in person_df.columns + and target_link_key in person_df.columns + ): + person_link = person_df[ + [source_link_key, target_link_key] + ].drop_duplicates() + + # Rename source key to match link key if needed + source_df_copy = source_df.copy() + if ( + source_link_key != source_key + and source_key in source_df_copy.columns + ): + source_df_copy = source_df_copy.rename( + columns={source_key: source_link_key} + ) + + # Join source data with target key + source_with_target = source_df_copy.merge( + person_link, on=source_link_key, how="left" + ) + + # Aggregate to target level + id_cols = { + col + for col in source_with_target.columns + if col.endswith("_id") + } + weight_cols = { + col + for col in source_with_target.columns + if col.endswith("_weight") + } + agg_cols = [ + c + for c in source_with_target.columns + if c not in id_cols and c not in weight_cols + ] + + aggregated = source_with_target.groupby( + target_link_key, as_index=False + )[agg_cols].sum() + + # Rename target link key to target key if needed + if target_link_key != target_key: + aggregated = aggregated.rename( + columns={target_link_key: target_key} + ) + + # Merge with target, preserving original order + target_pd = pd.DataFrame(target_df)[[target_key, target_weight]] + target_pd = target_pd.reset_index(drop=False) + result = target_pd.merge(aggregated, on=target_key, how="left") + + # Sort back to original order + result = ( + result.sort_values("index") + .drop("index", axis=1) + .reset_index(drop=True) + ) + + # Fill NaN with 0 + result[agg_cols] = result[agg_cols].fillna(0) + + return MicroDataFrame(result, weights=target_weight) + + raise ValueError( + f"Unsupported mapping from {source_entity} to {target_entity}" + ) 
diff --git a/src/policyengine/tax_benefit_models/uk/datasets.py b/src/policyengine/tax_benefit_models/uk/datasets.py index 309ef4bb..0b94e4ca 100644 --- a/src/policyengine/tax_benefit_models/uk/datasets.py +++ b/src/policyengine/tax_benefit_models/uk/datasets.py @@ -1,4 +1,4 @@ -from policyengine.core import Dataset +from policyengine.core import Dataset, map_to_entity from pydantic import BaseModel, ConfigDict import pandas as pd from microdf import MicroDataFrame @@ -30,103 +30,18 @@ def map_to_entity( Raises: ValueError: If source or target entity is invalid. """ - valid_entities = {"person", "benunit", "household"} - if source_entity not in valid_entities: - raise ValueError( - f"Invalid source entity '{source_entity}'. Must be one of {valid_entities}" - ) - if target_entity not in valid_entities: - raise ValueError( - f"Invalid target entity '{target_entity}'. Must be one of {valid_entities}" - ) - - # Get source data - source_df = getattr(self, source_entity) - if columns: - # Select only requested columns (keep join keys) - join_keys = {"person_id", "benunit_id", "household_id"} - cols_to_keep = list( - set(columns) | (join_keys & set(source_df.columns)) - ) - source_df = source_df[cols_to_keep] - - # Determine weight column for target entity - weight_col_map = { - "person": "person_weight", - "benunit": "benunit_weight", - "household": "household_weight", + entity_data = { + "person": self.person, + "benunit": self.benunit, + "household": self.household, } - target_weight = weight_col_map[target_entity] - - # Same entity - return as is - if source_entity == target_entity: - return MicroDataFrame( - pd.DataFrame(source_df), weights=target_weight - ) - - # Map to different entity - target_df = getattr(self, target_entity) - - # Person -> Benunit - if source_entity == "person" and target_entity == "benunit": - result = pd.DataFrame(target_df).merge( - pd.DataFrame(source_df), on="benunit_id", how="left" - ) - return MicroDataFrame(result, weights=target_weight) - - # Person -> Household - elif source_entity == "person" and target_entity == "household": - result = pd.DataFrame(target_df).merge( - pd.DataFrame(source_df), on="household_id", how="left" - ) - return MicroDataFrame(result, weights=target_weight) - - # Benunit -> Person - elif source_entity == "benunit" and target_entity == "person": - result = pd.DataFrame(target_df).merge( - pd.DataFrame(source_df), on="benunit_id", how="left" - ) - return MicroDataFrame(result, weights=target_weight) - - # Benunit -> Household - elif source_entity == "benunit" and target_entity == "household": - # Need to go through person to link benunit and household - person_link = pd.DataFrame(self.person)[ - ["benunit_id", "household_id"] - ].drop_duplicates() - source_with_hh = pd.DataFrame(source_df).merge( - person_link, on="benunit_id", how="left" - ) - result = pd.DataFrame(target_df).merge( - source_with_hh, on="household_id", how="left" - ) - return MicroDataFrame(result, weights=target_weight) - - # Household -> Person - elif source_entity == "household" and target_entity == "person": - result = pd.DataFrame(target_df).merge( - pd.DataFrame(source_df), on="household_id", how="left" - ) - return MicroDataFrame(result, weights=target_weight) - - # Household -> Benunit - elif source_entity == "household" and target_entity == "benunit": - # Need to go through person to link household and benunit - person_link = pd.DataFrame(self.person)[ - ["benunit_id", "household_id"] - ].drop_duplicates() - source_with_bu = pd.DataFrame(source_df).merge( - 
person_link, on="household_id", how="left" - ) - result = pd.DataFrame(target_df).merge( - source_with_bu, on="benunit_id", how="left" - ) - return MicroDataFrame(result, weights=target_weight) - - else: - raise ValueError( - f"Unsupported mapping from {source_entity} to {target_entity}" - ) + return map_to_entity( + entity_data=entity_data, + source_entity=source_entity, + target_entity=target_entity, + person_entity="person", + columns=columns, + ) class PolicyEngineUKDataset(Dataset): diff --git a/src/policyengine/tax_benefit_models/us/datasets.py b/src/policyengine/tax_benefit_models/us/datasets.py index adf08f8b..b98497ee 100644 --- a/src/policyengine/tax_benefit_models/us/datasets.py +++ b/src/policyengine/tax_benefit_models/us/datasets.py @@ -1,4 +1,4 @@ -from policyengine.core import Dataset +from policyengine.core import Dataset, map_to_entity from pydantic import BaseModel, ConfigDict import pandas as pd from microdf import MicroDataFrame @@ -33,163 +33,20 @@ def map_to_entity( Raises: ValueError: If source or target entity is invalid. """ - valid_entities = { - "person", - "marital_unit", - "family", - "spm_unit", - "tax_unit", - "household", + entity_data = { + "person": self.person, + "marital_unit": self.marital_unit, + "family": self.family, + "spm_unit": self.spm_unit, + "tax_unit": self.tax_unit, + "household": self.household, } - if source_entity not in valid_entities: - raise ValueError( - f"Invalid source entity '{source_entity}'. Must be one of {valid_entities}" - ) - if target_entity not in valid_entities: - raise ValueError( - f"Invalid target entity '{target_entity}'. Must be one of {valid_entities}" - ) - - # Get source data (as raw pandas DataFrame, not MicroDataFrame) - source_df = getattr(self, source_entity) - # Convert MicroDataFrame to plain pandas DataFrame to avoid weighted values - source_df = pd.DataFrame(source_df) - - if columns: - # Select only requested columns (keep join keys) - join_keys = { - "person_id", - "marital_unit_id", - "family_id", - "spm_unit_id", - "tax_unit_id", - "household_id", - } - cols_to_keep = list( - set(columns) | (join_keys & set(source_df.columns)) - ) - source_df = source_df[cols_to_keep] - - # Determine weight column for target entity - weight_col_map = { - "person": "person_weight", - "marital_unit": "marital_unit_weight", - "family": "family_weight", - "spm_unit": "spm_unit_weight", - "tax_unit": "tax_unit_weight", - "household": "household_weight", - } - target_weight = weight_col_map[target_entity] - - # Same entity - return as is - if source_entity == target_entity: - return MicroDataFrame( - pd.DataFrame(source_df), weights=target_weight - ) - - # Map to different entity - target_df = getattr(self, target_entity) - target_key = f"{target_entity}_id" - - # Person to group entity: aggregate person-level data to group level - if source_entity == "person" and target_entity != "person": - if target_key in source_df.columns: - # Aggregate person-level data to household level first - source_pd = source_df - # Get columns to aggregate (exclude join keys and weights) - join_keys = { - "person_id", - "marital_unit_id", - "family_id", - "spm_unit_id", - "tax_unit_id", - "household_id", - } - # Also exclude weight columns - weight_cols = { - "person_weight", - "marital_unit_weight", - "family_weight", - "spm_unit_weight", - "tax_unit_weight", - "household_weight", - } - agg_cols = [c for c in source_pd.columns if c not in join_keys and c not in weight_cols] - # Group by target key and sum - aggregated = 
source_pd.groupby(target_key, as_index=False)[agg_cols].sum() - # Merge with target, preserving original order - target_pd = pd.DataFrame(target_df)[[target_key, target_weight]] - # Add index to preserve order - target_pd = target_pd.reset_index(drop=False) - result = target_pd.merge( - aggregated, on=target_key, how="left" - ) - # Sort back to original order - result = result.sort_values('index').drop('index', axis=1).reset_index(drop=True) - # Fill NaN with 0 for households with no members in source entity - result[agg_cols] = result[agg_cols].fillna(0) - # Return MicroDataFrame with proper weights - return MicroDataFrame(result, weights=target_weight) - - # Group entity to person: expand group-level data to person level - if source_entity != "person" and target_entity == "person": - source_key = f"{source_entity}_id" - target_pd = pd.DataFrame(target_df) - if source_key in target_pd.columns: - result = target_pd.merge( - source_df, on=source_key, how="left" - ) - return MicroDataFrame(result, weights=target_weight) - - # Group to group: go through person table - if source_entity != "person" and target_entity != "person": - # Get person link table with both entity IDs - person_df = pd.DataFrame(self.person) - source_key = f"{source_entity}_id" - - # Link source -> person -> target - if ( - source_key in person_df.columns - and target_key in person_df.columns - ): - person_link = person_df[ - [source_key, target_key] - ].drop_duplicates() - # Join source data with target key - source_with_target = source_df.merge( - person_link, on=source_key, how="left" - ) - # Aggregate to target level - join_keys_all = { - "person_id", "marital_unit_id", "family_id", - "spm_unit_id", "tax_unit_id", "household_id", - } - weight_cols = { - "person_weight", - "marital_unit_weight", - "family_weight", - "spm_unit_weight", - "tax_unit_weight", - "household_weight", - } - agg_cols = [c for c in source_with_target.columns if c not in join_keys_all and c not in weight_cols] - aggregated = source_with_target.groupby(target_key, as_index=False)[agg_cols].sum() - # Merge with target, preserving original order - target_pd = pd.DataFrame(target_df)[[target_key, target_weight]] - # Add index to preserve order - target_pd = target_pd.reset_index(drop=False) - result = target_pd.merge( - aggregated, on=target_key, how="left" - ) - # Sort back to original order - result = result.sort_values('index').drop('index', axis=1).reset_index(drop=True) - # Fill NaN with 0 - result[agg_cols] = result[agg_cols].fillna(0) - # Return MicroDataFrame with proper weights - return MicroDataFrame(result, weights=target_weight) - - raise ValueError( - f"Unsupported mapping from {source_entity} to {target_entity}" + return map_to_entity( + entity_data=entity_data, + source_entity=source_entity, + target_entity=target_entity, + person_entity="person", + columns=columns, ) diff --git a/src/policyengine/tax_benefit_models/us/model.py b/src/policyengine/tax_benefit_models/us/model.py index fda04ad6..bb1747f9 100644 --- a/src/policyengine/tax_benefit_models/us/model.py +++ b/src/policyengine/tax_benefit_models/us/model.py @@ -223,11 +223,53 @@ def run(self, simulation: "Simulation") -> "Simulation": "household": pd.DataFrame(), } + # ID columns should be preserved from input dataset, not calculated + id_columns = { + "person_id", + "household_id", + "marital_unit_id", + "family_id", + "spm_unit_id", + "tax_unit_id", + } + weight_columns = { + "person_weight", + "household_weight", + "marital_unit_weight", + "family_weight", + 
"spm_unit_weight", + "tax_unit_weight", + } + + # First, copy ID and weight columns from input dataset + for entity in data.keys(): + input_df = pd.DataFrame(getattr(dataset.data, entity)) + entity_id_col = f"{entity}_id" + entity_weight_col = f"{entity}_weight" + + if entity_id_col in input_df.columns: + data[entity][entity_id_col] = input_df[entity_id_col].values + if entity_weight_col in input_df.columns: + data[entity][entity_weight_col] = input_df[ + entity_weight_col + ].values + + # For person entity, also copy person-level group ID columns + person_input_df = pd.DataFrame(dataset.data.person) + for col in person_input_df.columns: + if col.startswith("person_") and col.endswith("_id"): + # Map person_household_id -> household_id, etc. + target_col = col.replace("person_", "") + if target_col in id_columns: + data["person"][target_col] = person_input_df[col].values + + # Then calculate non-ID, non-weight variables from simulation for entity, variables in entity_variables.items(): for var in variables: - data[entity][var] = microsim.calculate( - var, period=simulation.dataset.year, map_to=entity - ).values + if var not in id_columns and var not in weight_columns: + data[entity][var] = microsim.calculate( + var, period=simulation.dataset.year, map_to=entity + ).values data["person"] = MicroDataFrame( data["person"], weights="person_weight" diff --git a/tests/test_entity_mapping.py b/tests/test_entity_mapping.py index ffbecacd..48c39e3d 100644 --- a/tests/test_entity_mapping.py +++ b/tests/test_entity_mapping.py @@ -53,7 +53,7 @@ def test_map_same_entity(): def test_map_person_to_benunit(): - """Test mapping person-level data to benunit level.""" + """Test mapping person-level data to benunit level aggregates correctly.""" person_df = MicroDataFrame( pd.DataFrame( { @@ -82,19 +82,24 @@ def test_map_person_to_benunit(): person=person_df, benunit=benunit_df, household=household_df ) - result = data.map_to_entity("person", "benunit") + result = data.map_to_entity("person", "benunit", columns=["income"]) # Should return a MicroDataFrame assert isinstance(result, MicroDataFrame) - # Should have rows for each person - assert len(result) == 3 - # Should have benunit data merged in + # Should have rows for each benunit (aggregated) + assert len(result) == 2 + # Should have benunit data with aggregated income assert "benunit_id" in result.columns - assert "person_id" in result.columns + assert "income" in result.columns + + # Income should be aggregated (summed) at benunit level + benunit_incomes = result.set_index("benunit_id")["income"].to_dict() + assert benunit_incomes[1] == 80000 # 50000 + 30000 + assert benunit_incomes[2] == 60000 # 60000 def test_map_person_to_household(): - """Test mapping person-level data to household level.""" + """Test mapping person-level data to household level aggregates correctly.""" person_df = MicroDataFrame( pd.DataFrame( { @@ -102,6 +107,7 @@ def test_map_person_to_household(): "benunit_id": [1, 1, 2], "household_id": [1, 1, 2], "age": [30, 25, 40], + "income": [50000, 30000, 60000], "person_weight": [1.0, 1.0, 1.0], } ), @@ -128,18 +134,22 @@ def test_map_person_to_household(): person=person_df, benunit=benunit_df, household=household_df ) - result = data.map_to_entity("person", "household") + result = data.map_to_entity("person", "household", columns=["income"]) - # Should have rows for each person - assert len(result) == 3 - # Should have household data merged in + # Should have rows for each household (aggregated) + assert len(result) == 2 + # Should have 
household data with aggregated income assert "household_id" in result.columns - assert "person_id" in result.columns - assert "rent" in result.columns + assert "income" in result.columns + + # Income should be aggregated (summed) at household level + household_incomes = result.set_index("household_id")["income"].to_dict() + assert household_incomes[1] == 80000 # 50000 + 30000 + assert household_incomes[2] == 60000 # 60000 def test_map_benunit_to_person(): - """Test mapping benunit-level data to person level.""" + """Test mapping benunit-level data to person level expands correctly.""" person_df = MicroDataFrame( pd.DataFrame( { @@ -172,23 +182,29 @@ def test_map_benunit_to_person(): person=person_df, benunit=benunit_df, household=household_df ) - result = data.map_to_entity("benunit", "person") + result = data.map_to_entity("benunit", "person", columns=["total_benefit"]) # Should have rows for each person assert len(result) == 3 - # Should have benunit data merged in + # Should have benunit data merged in (expanded/replicated) assert "benunit_id" in result.columns assert "person_id" in result.columns assert "total_benefit" in result.columns + # Benefit should be replicated to all persons in benunit + person_benefits = result.set_index("person_id")["total_benefit"].to_dict() + assert person_benefits[1] == 1000 # Person 1 in benunit 1 + assert person_benefits[2] == 1000 # Person 2 in benunit 1 + assert person_benefits[3] == 500 # Person 3 in benunit 2 + def test_map_benunit_to_household(): - """Test mapping benunit-level data to household level.""" + """Test mapping benunit-level data to household level aggregates via person.""" person_df = MicroDataFrame( pd.DataFrame( { "person_id": [1, 2, 3, 4], - "benunit_id": [1, 1, 2, 2], + "benunit_id": [1, 1, 2, 3], "household_id": [1, 1, 2, 2], "person_weight": [1.0, 1.0, 1.0, 1.0], } @@ -199,9 +215,9 @@ def test_map_benunit_to_household(): benunit_df = MicroDataFrame( pd.DataFrame( { - "benunit_id": [1, 2], - "total_benefit": [1000, 500], - "benunit_weight": [1.0, 1.0], + "benunit_id": [1, 2, 3], + "total_benefit": [1000, 500, 300], + "benunit_weight": [1.0, 1.0, 1.0], } ), weights="benunit_weight", @@ -216,13 +232,24 @@ def test_map_benunit_to_household(): person=person_df, benunit=benunit_df, household=household_df ) - result = data.map_to_entity("benunit", "household") + result = data.map_to_entity( + "benunit", "household", columns=["total_benefit"] + ) - # Should have benunit and household data - assert "benunit_id" in result.columns + # Should have household data (aggregated) + assert len(result) == 2 assert "household_id" in result.columns assert "total_benefit" in result.columns + # Benefits should be aggregated at household level + # Household 1 has benunit 1 (1000) + # Household 2 has benunit 2 (500) and benunit 3 (300) = 800 + household_benefits = result.set_index("household_id")[ + "total_benefit" + ].to_dict() + assert household_benefits[1] == 1000 + assert household_benefits[2] == 800 + def test_map_household_to_person(): """Test mapping household-level data to person level.""" @@ -269,7 +296,7 @@ def test_map_household_to_person(): def test_map_household_to_benunit(): - """Test mapping household-level data to benunit level.""" + """Test mapping household-level data to benunit level expands via person.""" person_df = MicroDataFrame( pd.DataFrame( { @@ -302,13 +329,19 @@ def test_map_household_to_benunit(): person=person_df, benunit=benunit_df, household=household_df ) - result = data.map_to_entity("household", "benunit") + result = 
data.map_to_entity("household", "benunit", columns=["rent"]) - # Should have benunit and household data + # Should have benunit data (expanded from household via person) + # Since benunit-household is 1:1 in this case, should have 2 rows + assert len(result) == 2 assert "benunit_id" in result.columns - assert "household_id" in result.columns assert "rent" in result.columns + # Rent should be mapped from household to benunit + benunit_rents = result.set_index("benunit_id")["rent"].to_dict() + assert benunit_rents[1] == 1000 # Benunit 1 in household 1 + assert benunit_rents[2] == 800 # Benunit 2 in household 2 + def test_map_with_column_selection(): """Test mapping with specific column selection.""" @@ -340,13 +373,15 @@ def test_map_with_column_selection(): person=person_df, benunit=benunit_df, household=household_df ) - # Map only age to household + # Map only age to household (aggregated) result = data.map_to_entity("person", "household", columns=["age"]) assert "age" in result.columns assert "household_id" in result.columns # income should not be included assert "income" not in result.columns + # Should be aggregated to household level + assert len(result) == 2 def test_invalid_entity_names(): diff --git a/tests/test_us_entity_mapping.py b/tests/test_us_entity_mapping.py new file mode 100644 index 00000000..59fe6913 --- /dev/null +++ b/tests/test_us_entity_mapping.py @@ -0,0 +1,333 @@ +import pandas as pd +import pytest +from microdf import MicroDataFrame +from policyengine.tax_benefit_models.us import USYearData + + +def test_map_same_entity(): + """Test mapping from an entity to itself returns the same data.""" + person_df = MicroDataFrame( + pd.DataFrame( + { + "person_id": [1, 2, 3], + "household_id": [1, 1, 2], + "tax_unit_id": [1, 1, 2], + "age": [30, 25, 40], + "person_weight": [1.0, 1.0, 1.0], + } + ), + weights="person_weight", + ) + + household_df = MicroDataFrame( + pd.DataFrame({"household_id": [1, 2], "household_weight": [1.0, 1.0]}), + weights="household_weight", + ) + + tax_unit_df = MicroDataFrame( + pd.DataFrame({"tax_unit_id": [1, 2], "tax_unit_weight": [1.0, 1.0]}), + weights="tax_unit_weight", + ) + + marital_unit_df = MicroDataFrame( + pd.DataFrame( + {"marital_unit_id": [1, 2], "marital_unit_weight": [1.0, 1.0]} + ), + weights="marital_unit_weight", + ) + + family_df = MicroDataFrame( + pd.DataFrame({"family_id": [1, 2], "family_weight": [1.0, 1.0]}), + weights="family_weight", + ) + + spm_unit_df = MicroDataFrame( + pd.DataFrame({"spm_unit_id": [1, 2], "spm_unit_weight": [1.0, 1.0]}), + weights="spm_unit_weight", + ) + + data = USYearData( + person=person_df, + household=household_df, + tax_unit=tax_unit_df, + marital_unit=marital_unit_df, + family=family_df, + spm_unit=spm_unit_df, + ) + + # Test person -> person + result = data.map_to_entity("person", "person") + assert isinstance(result, MicroDataFrame) + assert len(result) == 3 + assert list(result["person_id"]) == [1, 2, 3] + + +def test_map_person_to_household_aggregates(): + """Test mapping person-level data to household level aggregates correctly.""" + person_df = MicroDataFrame( + pd.DataFrame( + { + "person_id": [1, 2, 3, 4], + "household_id": [1, 1, 2, 2], + "tax_unit_id": [1, 1, 2, 2], + "income": [50000, 30000, 60000, 40000], + "person_weight": [1.0, 1.0, 1.0, 1.0], + } + ), + weights="person_weight", + ) + + household_df = MicroDataFrame( + pd.DataFrame( + { + "household_id": [1, 2], + "rent": [1000, 800], + "household_weight": [1.0, 1.0], + } + ), + weights="household_weight", + ) + + tax_unit_df = 
MicroDataFrame( + pd.DataFrame({"tax_unit_id": [1, 2], "tax_unit_weight": [1.0, 1.0]}), + weights="tax_unit_weight", + ) + + marital_unit_df = MicroDataFrame( + pd.DataFrame( + {"marital_unit_id": [1, 2], "marital_unit_weight": [1.0, 1.0]} + ), + weights="marital_unit_weight", + ) + + family_df = MicroDataFrame( + pd.DataFrame({"family_id": [1, 2], "family_weight": [1.0, 1.0]}), + weights="family_weight", + ) + + spm_unit_df = MicroDataFrame( + pd.DataFrame({"spm_unit_id": [1, 2], "spm_unit_weight": [1.0, 1.0]}), + weights="spm_unit_weight", + ) + + data = USYearData( + person=person_df, + household=household_df, + tax_unit=tax_unit_df, + marital_unit=marital_unit_df, + family=family_df, + spm_unit=spm_unit_df, + ) + + result = data.map_to_entity("person", "household", columns=["income"]) + + # Should return household-level data + assert isinstance(result, MicroDataFrame) + assert len(result) == 2 + + # Income should be aggregated (summed) at household level + assert "income" in result.columns + household_incomes = result.set_index("household_id")["income"].to_dict() + assert household_incomes[1] == 80000 # 50000 + 30000 + assert household_incomes[2] == 100000 # 60000 + 40000 + + +def test_map_household_to_person_expands(): + """Test mapping household-level data to person level expands correctly.""" + person_df = MicroDataFrame( + pd.DataFrame( + { + "person_id": [1, 2, 3], + "household_id": [1, 1, 2], + "tax_unit_id": [1, 1, 2], + "person_weight": [1.0, 1.0, 1.0], + } + ), + weights="person_weight", + ) + + household_df = MicroDataFrame( + pd.DataFrame( + { + "household_id": [1, 2], + "rent": [1000, 800], + "household_weight": [1.0, 1.0], + } + ), + weights="household_weight", + ) + + tax_unit_df = MicroDataFrame( + pd.DataFrame({"tax_unit_id": [1, 2], "tax_unit_weight": [1.0, 1.0]}), + weights="tax_unit_weight", + ) + + marital_unit_df = MicroDataFrame( + pd.DataFrame( + {"marital_unit_id": [1, 2], "marital_unit_weight": [1.0, 1.0]} + ), + weights="marital_unit_weight", + ) + + family_df = MicroDataFrame( + pd.DataFrame({"family_id": [1, 2], "family_weight": [1.0, 1.0]}), + weights="family_weight", + ) + + spm_unit_df = MicroDataFrame( + pd.DataFrame({"spm_unit_id": [1, 2], "spm_unit_weight": [1.0, 1.0]}), + weights="spm_unit_weight", + ) + + data = USYearData( + person=person_df, + household=household_df, + tax_unit=tax_unit_df, + marital_unit=marital_unit_df, + family=family_df, + spm_unit=spm_unit_df, + ) + + result = data.map_to_entity("household", "person", columns=["rent"]) + + # Should have rows for each person + assert len(result) == 3 + # Should have household data merged in (replicated) + assert "household_id" in result.columns + assert "person_id" in result.columns + assert "rent" in result.columns + + # Rent should be replicated to all persons in household + person_rents = result.set_index("person_id")["rent"].to_dict() + assert person_rents[1] == 1000 # Person 1 in household 1 + assert person_rents[2] == 1000 # Person 2 in household 1 + assert person_rents[3] == 800 # Person 3 in household 2 + + +def test_map_tax_unit_to_household_via_person(): + """Test mapping tax_unit to household goes through person and aggregates.""" + person_df = MicroDataFrame( + pd.DataFrame( + { + "person_id": [1, 2, 3, 4], + "household_id": [1, 1, 2, 2], + "tax_unit_id": [1, 1, 2, 3], + "person_weight": [1.0, 1.0, 1.0, 1.0], + } + ), + weights="person_weight", + ) + + household_df = MicroDataFrame( + pd.DataFrame({"household_id": [1, 2], "household_weight": [1.0, 1.0]}), + 
weights="household_weight", + ) + + tax_unit_df = MicroDataFrame( + pd.DataFrame( + { + "tax_unit_id": [1, 2, 3], + "taxable_income": [80000, 60000, 40000], + "tax_unit_weight": [1.0, 1.0, 1.0], + } + ), + weights="tax_unit_weight", + ) + + marital_unit_df = MicroDataFrame( + pd.DataFrame( + {"marital_unit_id": [1, 2], "marital_unit_weight": [1.0, 1.0]} + ), + weights="marital_unit_weight", + ) + + family_df = MicroDataFrame( + pd.DataFrame({"family_id": [1, 2], "family_weight": [1.0, 1.0]}), + weights="family_weight", + ) + + spm_unit_df = MicroDataFrame( + pd.DataFrame({"spm_unit_id": [1, 2], "spm_unit_weight": [1.0, 1.0]}), + weights="spm_unit_weight", + ) + + data = USYearData( + person=person_df, + household=household_df, + tax_unit=tax_unit_df, + marital_unit=marital_unit_df, + family=family_df, + spm_unit=spm_unit_df, + ) + + result = data.map_to_entity( + "tax_unit", "household", columns=["taxable_income"] + ) + + # Should return household-level data + assert len(result) == 2 + assert "taxable_income" in result.columns + + # Income should be aggregated at household level + # Household 1 has tax_unit 1 (80000) + # Household 2 has tax_unit 2 (60000) and tax_unit 3 (40000) = 100000 + household_incomes = result.set_index("household_id")[ + "taxable_income" + ].to_dict() + assert household_incomes[1] == 80000 + assert household_incomes[2] == 100000 + + +def test_invalid_entity_names(): + """Test that invalid entity names raise ValueError.""" + person_df = MicroDataFrame( + pd.DataFrame( + { + "person_id": [1], + "household_id": [1], + "tax_unit_id": [1], + "person_weight": [1.0], + } + ), + weights="person_weight", + ) + + household_df = MicroDataFrame( + pd.DataFrame({"household_id": [1], "household_weight": [1.0]}), + weights="household_weight", + ) + + tax_unit_df = MicroDataFrame( + pd.DataFrame({"tax_unit_id": [1], "tax_unit_weight": [1.0]}), + weights="tax_unit_weight", + ) + + marital_unit_df = MicroDataFrame( + pd.DataFrame({"marital_unit_id": [1], "marital_unit_weight": [1.0]}), + weights="marital_unit_weight", + ) + + family_df = MicroDataFrame( + pd.DataFrame({"family_id": [1], "family_weight": [1.0]}), + weights="family_weight", + ) + + spm_unit_df = MicroDataFrame( + pd.DataFrame({"spm_unit_id": [1], "spm_unit_weight": [1.0]}), + weights="spm_unit_weight", + ) + + data = USYearData( + person=person_df, + household=household_df, + tax_unit=tax_unit_df, + marital_unit=marital_unit_df, + family=family_df, + spm_unit=spm_unit_df, + ) + + with pytest.raises(ValueError, match="Invalid source entity"): + data.map_to_entity("invalid", "person") + + with pytest.raises(ValueError, match="Invalid target entity"): + data.map_to_entity("person", "invalid") From 4fbe9948b71c4694b098b76e340a35405f96481e Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Sun, 16 Nov 2025 23:39:48 +0000 Subject: [PATCH 29/35] Versioning --- changelog_entry.yaml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..4c132743 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,4 @@ +- bump: minor + changes: + - Just basemodels, no sqlmodels. + - Clean, working analysis at both household and macro level for uk and us. 
From 6e10d90a3b0e065332d907f653a1c8d51b912cab Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Sun, 16 Nov 2025 23:51:23 +0000 Subject: [PATCH 30/35] Add docs --- README.md | 190 +++++++++++++++- docs/core-concepts.md | 443 ++++++++++++++++++++++++++++++++++++++ docs/country-models-uk.md | 340 +++++++++++++++++++++++++++++ docs/country-models-us.md | 413 +++++++++++++++++++++++++++++++++++ docs/dev.md | 2 +- docs/myst.yml | 4 +- 6 files changed, 1382 insertions(+), 10 deletions(-) create mode 100644 docs/core-concepts.md create mode 100644 docs/country-models-uk.md create mode 100644 docs/country-models-us.md diff --git a/README.md b/README.md index 11d18575..8c89c37f 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,186 @@ # PolicyEngine.py -Documentation +A Python package for tax-benefit microsimulation analysis. Run policy simulations, analyse distributional impacts, and visualise results across the UK and US. -- Parameters, variables, and values: `docs/01_parameters_variables.ipynb` -- Policies and dynamic: `docs/02_policies_dynamic.ipynb` -- Datasets: `docs/03_datasets.ipynb` -- Simulations: `docs/04_simulations.ipynb` -- Output data items: `docs/05_output_data_items.ipynb` -- Reports and users: `docs/06_reports_users.ipynb` +## Quick start -Open these notebooks in Jupyter or your preferred IDE to run the examples. +```python +from policyengine.core import Simulation +from policyengine.tax_benefit_models.uk import PolicyEngineUKDataset, uk_latest +from policyengine.outputs.aggregate import Aggregate, AggregateType + +# Load representative microdata +dataset = PolicyEngineUKDataset( + name="FRS 2023-24", + filepath="./data/frs_2023_24_year_2026.h5", + year=2026, +) + +# Run simulation +simulation = Simulation( + dataset=dataset, + tax_benefit_model_version=uk_latest, +) +simulation.run() + +# Calculate total universal credit spending +agg = Aggregate( + simulation=simulation, + variable="universal_credit", + aggregate_type=AggregateType.SUM, + entity="benunit", +) +agg.run() +print(f"Total UC spending: £{agg.result / 1e9:.1f}bn") +``` + +## Documentation + +**Core concepts:** +- [Core concepts](docs/core-concepts.md): Architecture, datasets, simulations, outputs +- [UK tax-benefit model](docs/country-models-uk.md): Entities, parameters, examples +- [US tax-benefit model](docs/country-models-us.md): Entities, parameters, examples + +**Examples:** +- `examples/income_distribution_us.py`: Analyse benefit distribution by decile +- `examples/employment_income_variation_uk.py`: Model employment income phase-outs +- `examples/policy_change_uk.py`: Analyse policy reform impacts + +## Installation + +```bash +pip install policyengine +``` + +## Features + +- **Multi-country support**: UK and US tax-benefit systems +- **Representative microdata**: Load FRS, CPS, or create custom scenarios +- **Policy reforms**: Parametric reforms with date-bound parameter values +- **Distributional analysis**: Aggregate statistics by income decile, demographics +- **Entity mapping**: Automatic mapping between person, household, tax unit levels +- **Visualisation**: PolicyEngine-branded charts with Plotly + +## Key concepts + +### Datasets + +Datasets contain microdata at entity level (person, household, tax unit). 
Load representative data or create custom scenarios: + +```python +from policyengine.tax_benefit_models.uk import PolicyEngineUKDataset + +dataset = PolicyEngineUKDataset( + name="Representative data", + filepath="./data/frs_2023_24_year_2026.h5", + year=2026, +) +dataset.load() +``` + +### Simulations + +Simulations apply tax-benefit models to datasets: + +```python +from policyengine.core import Simulation +from policyengine.tax_benefit_models.uk import uk_latest + +simulation = Simulation( + dataset=dataset, + tax_benefit_model_version=uk_latest, +) +simulation.run() + +# Access calculated variables +output = simulation.output_dataset.data +print(output.household[["household_net_income", "household_benefits"]]) +``` + +### Outputs + +Extract insights with aggregate statistics: + +```python +from policyengine.outputs.aggregate import Aggregate, AggregateType + +# Mean income in top decile +agg = Aggregate( + simulation=simulation, + variable="household_net_income", + aggregate_type=AggregateType.MEAN, + filter_variable="household_net_income", + quantile=10, + quantile_eq=10, +) +agg.run() +print(f"Top decile mean income: £{agg.result:,.0f}") +``` + +### Policy reforms + +Apply parametric reforms: + +```python +from policyengine.core import Policy, Parameter, ParameterValue +import datetime + +parameter = Parameter( + name="gov.hmrc.income_tax.allowances.personal_allowance.amount", + tax_benefit_model_version=uk_latest, + data_type=float, +) + +policy = Policy( + name="Increase personal allowance", + parameter_values=[ + ParameterValue( + parameter=parameter, + start_date=datetime.date(2026, 1, 1), + end_date=datetime.date(2026, 12, 31), + value=15000, + ) + ], +) + +# Run reform simulation +reform_sim = Simulation( + dataset=dataset, + tax_benefit_model_version=uk_latest, + policy=policy, +) +reform_sim.run() +``` + +## Country models + +### UK + +Three entity levels: +- **Person**: Individual with income and demographics +- **Benunit**: Benefit unit (single person or couple with children) +- **Household**: Residence unit + +Key benefits: Universal Credit, Child Benefit, Pension Credit +Key taxes: Income tax, National Insurance + +### US + +Six entity levels: +- **Person**: Individual +- **Tax unit**: Federal tax filing unit +- **SPM unit**: Supplemental Poverty Measure unit +- **Family**: Census family definition +- **Marital unit**: Married couple or single person +- **Household**: Residence unit + +Key benefits: SNAP, TANF, EITC, CTC, SSI, Social Security +Key taxes: Federal income tax, payroll tax + +## Contributing + +See [CONTRIBUTING.md](CONTRIBUTING.md) for development setup and guidelines. + +## License + +AGPL-3.0 diff --git a/docs/core-concepts.md b/docs/core-concepts.md new file mode 100644 index 00000000..bdba1f39 --- /dev/null +++ b/docs/core-concepts.md @@ -0,0 +1,443 @@ +# Core concepts + +PolicyEngine.py is a Python package for tax-benefit microsimulation analysis. It provides a unified interface for running policy simulations, analysing distributional impacts, and visualising results across different countries. + +## Architecture overview + +The package is organised around several core concepts: + +- **Tax-benefit models**: Country-specific implementations (UK, US) that define tax and benefit rules +- **Datasets**: Microdata representing populations at entity level (person, household, etc.) 
+- **Simulations**: Execution environments that apply tax-benefit models to datasets +- **Outputs**: Analysis tools for extracting insights from simulation results +- **Policies**: Parametric reforms that modify tax-benefit system parameters + +## Tax-benefit models + +Tax-benefit models define the rules and calculations for a country's tax and benefit system. Each model version contains: + +- **Variables**: Calculated values (e.g., income tax, universal credit) +- **Parameters**: System settings (e.g., personal allowance, benefit rates) +- **Parameter values**: Time-bound values for parameters + +### Using a tax-benefit model + +```python +from policyengine.tax_benefit_models.uk import uk_latest +from policyengine.tax_benefit_models.us import us_latest + +# UK model includes variables like: +# - income_tax, national_insurance, universal_credit +# - Parameters like personal allowance, NI thresholds + +# US model includes variables like: +# - income_tax, payroll_tax, eitc, ctc, snap +# - Parameters like standard deduction, EITC rates +``` + +## Datasets + +Datasets contain microdata representing a population. Each dataset has: + +- **Entity-level data**: Separate dataframes for person, household, and other entities +- **Weights**: Survey weights for population representation +- **Join keys**: Relationships between entities (e.g., which household each person belongs to) + +### Dataset structure + +```python +from policyengine.tax_benefit_models.uk import PolicyEngineUKDataset + +dataset = PolicyEngineUKDataset( + name="FRS 2023-24", + description="Family Resources Survey microdata", + filepath="./data/frs_2023_24_year_2026.h5", + year=2026, +) + +# Access entity-level data +person_data = dataset.data.person # MicroDataFrame +household_data = dataset.data.household +benunit_data = dataset.data.benunit # Benefit unit (UK only) +``` + +### Creating custom datasets + +You can create custom datasets for scenario analysis: + +```python +import pandas as pd +from microdf import MicroDataFrame +from policyengine.tax_benefit_models.uk import PolicyEngineUKDataset, UKYearData + +# Create person data +person_df = MicroDataFrame( + pd.DataFrame({ + "person_id": [0, 1, 2], + "person_household_id": [0, 0, 1], + "person_benunit_id": [0, 0, 1], + "age": [35, 8, 40], + "employment_income": [30000, 0, 50000], + "person_weight": [1.0, 1.0, 1.0], + }), + weights="person_weight" +) + +# Create household data +household_df = MicroDataFrame( + pd.DataFrame({ + "household_id": [0, 1], + "region": ["LONDON", "SOUTH_EAST"], + "rent": [15000, 12000], + "household_weight": [1.0, 1.0], + }), + weights="household_weight" +) + +# Create benunit data +benunit_df = MicroDataFrame( + pd.DataFrame({ + "benunit_id": [0, 1], + "would_claim_uc": [True, True], + "benunit_weight": [1.0, 1.0], + }), + weights="benunit_weight" +) + +dataset = PolicyEngineUKDataset( + name="Custom scenario", + description="Single parent vs single adult", + filepath="./custom.h5", + year=2026, + data=UKYearData( + person=person_df, + household=household_df, + benunit=benunit_df, + ) +) +``` + +## Simulations + +Simulations apply tax-benefit models to datasets, calculating all variables for the specified year. 
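+Running a simulation computes every variable at every entity level, so runtime scales with dataset size (the speedtest script added earlier in this patch series measures this). A minimal timing sketch, assuming `dataset` is a loaded UK dataset as above:
+
+```python
+import time
+
+from policyengine.core import Simulation
+from policyengine.tax_benefit_models.uk import uk_latest
+
+simulation = Simulation(
+    dataset=dataset,
+    tax_benefit_model_version=uk_latest,
+)
+
+start = time.time()
+simulation.run()
+print(
+    f"Simulated {len(dataset.data.household):,} households "
+    f"in {time.time() - start:.2f}s"
+)
+```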
+ +### Running a simulation + +```python +from policyengine.core import Simulation +from policyengine.tax_benefit_models.uk import uk_latest + +simulation = Simulation( + dataset=dataset, + tax_benefit_model_version=uk_latest, +) +simulation.run() + +# Access output data +output_person = simulation.output_dataset.data.person +output_household = simulation.output_dataset.data.household + +# Check calculated variables +print(output_household[["household_id", "household_net_income", "household_tax"]]) +``` + +### Accessing calculated variables + +After running a simulation, you can access the calculated variables from the output dataset: + +```python +simulation = Simulation( + dataset=dataset, + tax_benefit_model_version=uk_latest, +) +simulation.run() + +# Access specific variables +output = simulation.output_dataset.data +person_data = output.person[["person_id", "age", "employment_income", "income_tax"]] +household_data = output.household[["household_id", "household_net_income"]] +benunit_data = output.benunit[["benunit_id", "universal_credit", "child_benefit"]] +``` + +## Policies + +Policies modify tax-benefit system parameters through parametric reforms. + +### Creating a policy + +```python +from policyengine.core import Policy, Parameter, ParameterValue +import datetime + +# Define parameter to modify +parameter = Parameter( + name="gov.hmrc.income_tax.allowances.personal_allowance.amount", + tax_benefit_model_version=uk_latest, + description="Personal allowance for income tax", + data_type=float, +) + +# Set new value +parameter_value = ParameterValue( + parameter=parameter, + start_date=datetime.date(2026, 1, 1), + end_date=datetime.date(2026, 12, 31), + value=15000, # Increase from ~£12,570 to £15,000 +) + +policy = Policy( + name="Increased personal allowance", + description="Raises personal allowance to £15,000", + parameter_values=[parameter_value], +) +``` + +### Running a reform simulation + +```python +# Baseline simulation +baseline = Simulation( + dataset=dataset, + tax_benefit_model_version=uk_latest, +) +baseline.run() + +# Reform simulation +reform = Simulation( + dataset=dataset, + tax_benefit_model_version=uk_latest, + policy=policy, +) +reform.run() +``` + +## Outputs + +Output classes provide structured analysis of simulation results. 
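+
+Each output is a small Pydantic model: `run()` computes a statistic and stores it on the instance, and the base `Output.run()` raises `NotImplementedError`, so you can add bespoke outputs by subclassing. A minimal sketch of a hypothetical custom output (the class name and calculation below are illustrative, not part of the package):
+
+```python
+from policyengine.core import Output, Simulation
+
+
+class BenefitShare(Output):  # hypothetical example, not a built-in output
+    simulation: Simulation
+    result: float | None = None
+
+    def run(self):
+        household = self.simulation.output_dataset.data.household
+        # MicroDataFrame columns carry survey weights, so these sums are
+        # population-weighted totals
+        self.result = float(
+            household["household_benefits"].sum()
+            / household["household_net_income"].sum()
+        )
+
+
+share = BenefitShare(simulation=simulation)
+share.run()
+print(f"Benefits as a share of net income: {share.result:.1%}")
+```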
+ +### Aggregate + +Calculate aggregate statistics (sum, mean, count) for any variable: + +```python +from policyengine.outputs.aggregate import Aggregate, AggregateType + +# Total universal credit spending +agg = Aggregate( + simulation=simulation, + variable="universal_credit", + aggregate_type=AggregateType.SUM, + entity="benunit", # Map to benunit level +) +agg.run() +print(f"Total UC spending: £{agg.result / 1e9:.1f}bn") + +# Mean household income in top decile +agg = Aggregate( + simulation=simulation, + variable="household_net_income", + aggregate_type=AggregateType.MEAN, + filter_variable="household_net_income", + quantile=10, + quantile_eq=10, # 10th decile +) +agg.run() +print(f"Mean income in top decile: £{agg.result:,.0f}") +``` + +### ChangeAggregate + +Analyse impacts of policy reforms: + +```python +from policyengine.outputs.change_aggregate import ChangeAggregate, ChangeAggregateType + +# Count winners and losers +winners = ChangeAggregate( + baseline_simulation=baseline, + reform_simulation=reform, + variable="household_net_income", + aggregate_type=ChangeAggregateType.COUNT, + change_geq=1, # Gain at least £1 +) +winners.run() +print(f"Winners: {winners.result / 1e6:.1f}m households") + +losers = ChangeAggregate( + baseline_simulation=baseline, + reform_simulation=reform, + variable="household_net_income", + aggregate_type=ChangeAggregateType.COUNT, + change_leq=-1, # Lose at least £1 +) +losers.run() +print(f"Losers: {losers.result / 1e6:.1f}m households") + +# Revenue impact +revenue = ChangeAggregate( + baseline_simulation=baseline, + reform_simulation=reform, + variable="household_tax", + aggregate_type=ChangeAggregateType.SUM, +) +revenue.run() +print(f"Revenue change: £{revenue.result / 1e9:.1f}bn") +``` + +## Entity mapping + +The package automatically handles entity mapping when variables are defined at different entity levels. + +### Entity hierarchy + +**UK:** +``` +household + └── benunit (benefit unit) + └── person +``` + +**US:** +``` +household + ├── tax_unit + ├── spm_unit + ├── family + └── marital_unit + └── person +``` + +### Automatic mapping + +When you request a person-level variable (like `ssi`) at household level, the package: +1. Sums person-level values within each household (aggregation) +2. Returns household-level data with proper weights + +```python +# SSI is defined at person level, but we want household-level totals +agg = Aggregate( + simulation=simulation, + variable="ssi", # Person-level variable + entity="household", # Target household level + aggregate_type=AggregateType.SUM, +) +# Internally maps person → household by summing SSI for all persons in each household +``` + +When you request a household-level variable at person level: +1. Replicates household values to all persons in that household (expansion) + +## Visualisation + +The package includes utilities for creating PolicyEngine-branded visualisations: + +```python +from policyengine.utils.plotting import format_fig, COLORS +import plotly.graph_objects as go + +fig = go.Figure() +fig.add_trace(go.Scatter(x=[1, 2, 3], y=[4, 5, 6])) + +format_fig( + fig, + title="My chart", + xaxis_title="X axis", + yaxis_title="Y axis", + height=600, + width=800, +) +fig.show() +``` + +### Brand colours + +```python +COLORS = { + "primary": "#319795", # Teal + "success": "#22C55E", # Green + "warning": "#FEC601", # Yellow + "error": "#EF4444", # Red + "info": "#1890FF", # Blue + "blue_secondary": "#026AA2", # Dark blue + "gray": "#667085", # Gray +} +``` + +## Common workflows + +### 1. 
Analyse employment income variation + +See `examples/employment_income_variation_uk.py` for a complete example of: +- Creating custom datasets with varied parameters +- Running single simulations +- Extracting results with filters +- Visualising benefit phase-outs + +### 2. Policy reform analysis + +See `examples/policy_change_uk.py` for: +- Applying parametric reforms +- Comparing baseline and reform +- Analysing winners/losers by decile +- Calculating revenue impacts + +### 3. Distributional analysis + +See `examples/income_distribution_us.py` for: +- Loading representative microdata +- Calculating statistics by income decile +- Mapping variables across entity levels +- Creating interactive visualisations + +## Best practices + +### Creating custom datasets + +1. **Always set would_claim variables**: Benefits won't be claimed unless explicitly enabled + ```python + "would_claim_uc": [True] * n_households + ``` + +2. **Set disability variables explicitly**: Prevents random UC spikes from LCWRA element + ```python + "is_disabled_for_benefits": [False] * n_people + "uc_limited_capability_for_WRA": [False] * n_people + ``` + +3. **Include required join keys**: Person data needs entity membership + ```python + "person_household_id": household_ids + "person_benunit_id": benunit_ids # UK only + ``` + +4. **Set required household fields**: Vary by country + ```python + # UK + "region": ["LONDON"] * n_households + "tenure_type": ["RENT_PRIVATELY"] * n_households + + # US + "state_code": ["CA"] * n_households + ``` + +### Performance optimisation + +1. **Single simulation for variations**: Create all scenarios in one dataset, run once +2. **Custom variable selection**: Only calculate needed variables +3. **Filter efficiently**: Use quantile filters for decile analysis +4. **Parallel analysis**: Multiple Aggregate calls can run independently + +### Data integrity + +1. **Check weights**: Ensure weights sum to expected population +2. **Validate join keys**: All persons should link to valid households +3. **Review output ranges**: Check calculated values are reasonable +4. **Test edge cases**: Zero income, high income, disabled, elderly + +## Next steps + +- See `examples/` for complete working examples +- Review country-specific documentation: + - [UK tax-benefit model](country-models-uk.md) + - [US tax-benefit model](country-models-us.md) +- Explore the API reference for detailed class documentation diff --git a/docs/country-models-uk.md b/docs/country-models-uk.md new file mode 100644 index 00000000..27d7dae7 --- /dev/null +++ b/docs/country-models-uk.md @@ -0,0 +1,340 @@ +# UK tax-benefit model + +The UK tax-benefit model implements the United Kingdom's tax and benefit system using PolicyEngine UK as the underlying calculation engine. + +## Entity structure + +The UK model uses three entity levels: + +``` +household + └── benunit (benefit unit) + └── person +``` + +### Person + +Individual people with demographic and income characteristics. + +**Key variables:** +- `age`: Person's age in years +- `employment_income`: Annual employment income +- `self_employment_income`: Annual self-employment income +- `pension_income`: Annual pension income +- `savings_interest_income`: Annual interest from savings +- `dividend_income`: Annual dividend income +- `income_tax`: Total income tax paid +- `national_insurance`: Total NI contributions +- `is_disabled_for_benefits`: Whether disabled for benefit purposes + +### Benunit (benefit unit) + +The unit for benefit assessment. 
Usually a single person or a couple with dependent children.
+
+**Key variables:**
+- `universal_credit`: Annual UC payment
+- `child_benefit`: Annual child benefit
+- `working_tax_credit`: Annual WTC (legacy system)
+- `child_tax_credit`: Annual CTC (legacy system)
+- `pension_credit`: Annual pension credit
+- `income_support`: Annual income support
+- `housing_benefit`: Annual housing benefit
+- `council_tax_support`: Annual council tax support
+
+**Important flags:**
+- `would_claim_uc`: Must be True to claim UC
+- `would_claim_WTC`: Must be True to claim WTC
+- `would_claim_CTC`: Must be True to claim CTC
+- `would_claim_IS`: Must be True to claim IS
+- `would_claim_pc`: Must be True to claim pension credit
+- `would_claim_child_benefit`: Must be True to claim child benefit
+- `would_claim_housing_benefit`: Must be True to claim HB
+
+### Household
+
+The residence unit, typically sharing accommodation.
+
+**Key variables:**
+- `household_net_income`: Total household net income
+- `hbai_household_net_income`: HBAI-equivalised net income
+- `household_benefits`: Total benefits received
+- `household_tax`: Total tax paid
+- `household_market_income`: Total market income
+
+**Required fields:**
+- `region`: UK region (e.g., "LONDON", "SOUTH_EAST")
+- `tenure_type`: Housing tenure (e.g., "RENT_PRIVATELY", "OWNED_OUTRIGHT")
+- `rent`: Annual rent paid
+- `council_tax`: Annual council tax
+
+## Using the UK model
+
+### Loading representative data
+
+```python
+from policyengine.tax_benefit_models.uk import PolicyEngineUKDataset
+
+dataset = PolicyEngineUKDataset(
+    name="FRS 2023-24",
+    description="Family Resources Survey microdata",
+    filepath="./data/frs_2023_24_year_2026.h5",
+    year=2026,
+)
+
+# Read the saved entity tables into dataset.data
+dataset.load()
+
+print(f"People: {len(dataset.data.person):,}")
+print(f"Benefit units: {len(dataset.data.benunit):,}")
+print(f"Households: {len(dataset.data.household):,}")
+```
+
+### Creating custom scenarios
+
+```python
+import pandas as pd
+from microdf import MicroDataFrame
+from policyengine.tax_benefit_models.uk import UKYearData
+
+# Single parent with 2 children
+person_df = MicroDataFrame(
+    pd.DataFrame({
+        "person_id": [0, 1, 2],
+        "person_benunit_id": [0, 0, 0],
+        "person_household_id": [0, 0, 0],
+        "age": [35, 8, 5],
+        "employment_income": [25000, 0, 0],
+        "person_weight": [1.0, 1.0, 1.0],
+        "is_disabled_for_benefits": [False, False, False],
+        "uc_limited_capability_for_WRA": [False, False, False],
+    }),
+    weights="person_weight"
+)
+
+benunit_df = MicroDataFrame(
+    pd.DataFrame({
+        "benunit_id": [0],
+        "benunit_weight": [1.0],
+        "would_claim_uc": [True],
+        "would_claim_child_benefit": [True],
+        "would_claim_WTC": [True],
+        "would_claim_CTC": [True],
+    }),
+    weights="benunit_weight"
+)
+
+household_df = MicroDataFrame(
+    pd.DataFrame({
+        "household_id": [0],
+        "household_weight": [1.0],
+        "region": ["LONDON"],
+        "rent": [15000],  # £1,250/month
+        "council_tax": [2000],
+        "tenure_type": ["RENT_PRIVATELY"],
+    }),
+    weights="household_weight"
+)
+
+dataset = PolicyEngineUKDataset(
+    name="Single parent scenario",
+    description="One adult, two children",
+    filepath="./single_parent.h5",
+    year=2026,
+    data=UKYearData(
+        person=person_df,
+        benunit=benunit_df,
+        household=household_df,
+    )
+)
+```
+
+### Running a simulation
+
+```python
+from policyengine.core import Simulation
+from policyengine.tax_benefit_models.uk import uk_latest
+
+simulation = Simulation(
+    dataset=dataset,
+    tax_benefit_model_version=uk_latest,
+)
+simulation.run()
+
+# Check results
+output = simulation.output_dataset.data
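+
+# The output dataset mirrors the input structure (person, benunit and
+# household tables), with calculated variables such as household_net_income
+# now filled in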
+print(output.household[["household_net_income", "household_benefits", "household_tax"]]) +``` + +## Key parameters + +### Income tax + +- `gov.hmrc.income_tax.allowances.personal_allowance.amount`: Personal allowance (£12,570 in 2024-25) +- `gov.hmrc.income_tax.rates.uk[0].rate`: Basic rate (20%) +- `gov.hmrc.income_tax.rates.uk[1].rate`: Higher rate (40%) +- `gov.hmrc.income_tax.rates.uk[2].rate`: Additional rate (45%) +- `gov.hmrc.income_tax.rates.uk[0].threshold`: Basic rate threshold (£50,270) +- `gov.hmrc.income_tax.rates.uk[1].threshold`: Higher rate threshold (£125,140) + +### National insurance + +- `gov.hmrc.national_insurance.class_1.main.primary_threshold`: Primary threshold (£12,570) +- `gov.hmrc.national_insurance.class_1.main.upper_earnings_limit`: Upper earnings limit (£50,270) +- `gov.hmrc.national_insurance.class_1.main.rate`: Main rate (12% below UEL, 2% above) + +### Universal credit + +- `gov.dwp.universal_credit.elements.standard_allowance.single_adult`: Standard allowance for single adult (£334.91/month in 2024-25) +- `gov.dwp.universal_credit.elements.child.first_child`: First child element (£333.33/month) +- `gov.dwp.universal_credit.elements.child.subsequent_child`: Subsequent children (£287.92/month each) +- `gov.dwp.universal_credit.means_test.reduction_rate`: Taper rate (55%) +- `gov.dwp.universal_credit.means_test.earned_income.disregard`: Work allowance + +### Child benefit + +- `gov.hmrc.child_benefit.rates.eldest_child`: First child rate (£25.60/week) +- `gov.hmrc.child_benefit.rates.additional_child`: Additional children (£16.95/week each) +- `gov.hmrc.child_benefit.income_tax_charge.threshold`: HICBC threshold (£60,000) + +## Common policy reforms + +### Increasing personal allowance + +```python +from policyengine.core import Policy, Parameter, ParameterValue +import datetime + +parameter = Parameter( + name="gov.hmrc.income_tax.allowances.personal_allowance.amount", + tax_benefit_model_version=uk_latest, + description="Personal allowance", + data_type=float, +) + +policy = Policy( + name="Increase personal allowance to £15,000", + description="Raises personal allowance from £12,570 to £15,000", + parameter_values=[ + ParameterValue( + parameter=parameter, + start_date=datetime.date(2026, 1, 1), + end_date=datetime.date(2026, 12, 31), + value=15000, + ) + ], +) +``` + +### Adjusting UC taper rate + +```python +parameter = Parameter( + name="gov.dwp.universal_credit.means_test.reduction_rate", + tax_benefit_model_version=uk_latest, + description="UC taper rate", + data_type=float, +) + +policy = Policy( + name="Reduce UC taper to 50%", + description="Lowers taper rate from 55% to 50%", + parameter_values=[ + ParameterValue( + parameter=parameter, + start_date=datetime.date(2026, 1, 1), + end_date=datetime.date(2026, 12, 31), + value=0.50, # 50% + ) + ], +) +``` + +### Abolishing two-child limit + +```python +# Set subsequent child element equal to first child +parameter = Parameter( + name="gov.dwp.universal_credit.elements.child.subsequent_child", + tax_benefit_model_version=uk_latest, + description="UC subsequent child element", + data_type=float, +) + +policy = Policy( + name="Abolish two-child limit", + description="Sets subsequent child element equal to first child", + parameter_values=[ + ParameterValue( + parameter=parameter, + start_date=datetime.date(2026, 1, 1), + end_date=datetime.date(2026, 12, 31), + value=333.33, # Match first child rate + ) + ], +) +``` + +## Regional variations + +The UK model accounts for regional differences: + +- 
**Council tax**: Varies by local authority +- **Rent levels**: Regional housing markets +- **Scottish income tax**: Different rates and thresholds for Scottish taxpayers + +### Regions + +Valid region values: +- `LONDON` +- `SOUTH_EAST` +- `SOUTH_WEST` +- `EAST_OF_ENGLAND` +- `WEST_MIDLANDS` +- `EAST_MIDLANDS` +- `YORKSHIRE` +- `NORTH_WEST` +- `NORTH_EAST` +- `WALES` +- `SCOTLAND` +- `NORTHERN_IRELAND` + +## Data sources + +The UK model can use several data sources: + +1. **Family Resources Survey (FRS)**: Official UK household survey + - ~19,000 households + - Detailed income and benefit receipt + - Published annually + +2. **Enhanced FRS**: Uprated and enhanced version + - Calibrated to population totals + - Additional imputed variables + - Multiple projection years + +3. **Custom datasets**: User-created scenarios + - Full control over household composition + - Exact income levels + - Specific benefit claiming patterns + +## Validation + +When creating custom datasets, validate: + +1. **Would claim flags**: All set to True +2. **Disability flags**: Set explicitly (not random) +3. **Join keys**: Person data links to benunits and households +4. **Required fields**: Region, tenure_type set correctly +5. **Weights**: Sum to expected values +6. **Income ranges**: Realistic values + +## Examples + +See working examples in the `examples/` directory: + +- `employment_income_variation_uk.py`: Vary employment income, analyse benefit phase-outs +- `policy_change_uk.py`: Apply reforms, analyse winners/losers +- `income_bands_uk.py`: Create income band scenarios + +## References + +- PolicyEngine UK documentation: https://policyengine.github.io/policyengine-uk/ +- UK tax-benefit system: https://www.gov.uk/browse/benefits +- HBAI methodology: https://www.gov.uk/government/statistics/households-below-average-income-for-financial-years-ending-1995-to-2023 diff --git a/docs/country-models-us.md b/docs/country-models-us.md new file mode 100644 index 00000000..927966ee --- /dev/null +++ b/docs/country-models-us.md @@ -0,0 +1,413 @@ +# US tax-benefit model + +The US tax-benefit model implements the United States federal tax and benefit system using PolicyEngine US as the underlying calculation engine. + +## Entity structure + +The US model uses a more complex entity hierarchy: + +``` +household + ├── tax_unit (federal tax filing unit) + ├── spm_unit (Supplemental Poverty Measure unit) + ├── family (Census definition) + └── marital_unit (married couple or single person) + └── person +``` + +### Person + +Individual people with demographic and income characteristics. + +**Key variables:** +- `age`: Person's age in years +- `employment_income`: Annual employment income +- `self_employment_income`: Annual self-employment income +- `social_security`: Annual Social Security benefits +- `ssi`: Annual Supplemental Security Income +- `medicaid`: Annual Medicaid value +- `medicare`: Annual Medicare value +- `unemployment_compensation`: Annual unemployment benefits + +### Tax unit + +The federal tax filing unit (individual or married filing jointly). + +**Key variables:** +- `income_tax`: Federal income tax liability +- `employee_payroll_tax`: Employee payroll tax (FICA) +- `eitc`: Earned Income Tax Credit +- `ctc`: Child Tax Credit +- `income_tax_before_credits`: Tax before credits + +### SPM unit + +The Supplemental Poverty Measure unit used for SNAP and other means-tested benefits. 
+ +**Key variables:** +- `snap`: Annual SNAP (food stamps) benefits +- `tanf`: Annual TANF (cash assistance) benefits +- `spm_unit_net_income`: SPM net income +- `spm_unit_size`: Number of people in unit + +### Family + +Census definition of family (related individuals). + +**Key variables:** +- `family_id`: Family identifier +- `family_weight`: Survey weight + +### Marital unit + +Married couple or single person. + +**Key variables:** +- `marital_unit_id`: Marital unit identifier +- `marital_unit_weight`: Survey weight + +### Household + +The residence unit. + +**Key variables:** +- `household_net_income`: Total household net income +- `household_benefits`: Total benefits received +- `household_tax`: Total tax paid +- `household_market_income`: Total market income before taxes and transfers + +**Required fields:** +- `state_code`: State (e.g., "CA", "NY", "TX") + +## Using the US model + +### Loading representative data + +```python +from policyengine.tax_benefit_models.us import PolicyEngineUSDataset + +dataset = PolicyEngineUSDataset( + name="Enhanced CPS 2024", + description="Enhanced Current Population Survey microdata", + filepath="./data/enhanced_cps_2024_year_2024.h5", + year=2024, +) + +print(f"People: {len(dataset.data.person):,}") +print(f"Tax units: {len(dataset.data.tax_unit):,}") +print(f"SPM units: {len(dataset.data.spm_unit):,}") +print(f"Households: {len(dataset.data.household):,}") +``` + +### Creating custom scenarios + +```python +import pandas as pd +from microdf import MicroDataFrame +from policyengine.tax_benefit_models.us import USYearData + +# Married couple with 2 children +person_df = MicroDataFrame( + pd.DataFrame({ + "person_id": [0, 1, 2, 3], + "person_household_id": [0, 0, 0, 0], + "person_tax_unit_id": [0, 0, 0, 0], + "person_spm_unit_id": [0, 0, 0, 0], + "person_family_id": [0, 0, 0, 0], + "person_marital_unit_id": [0, 0, 1, 2], + "age": [35, 33, 8, 5], + "employment_income": [60000, 40000, 0, 0], + "person_weight": [1.0, 1.0, 1.0, 1.0], + }), + weights="person_weight" +) + +tax_unit_df = MicroDataFrame( + pd.DataFrame({ + "tax_unit_id": [0], + "tax_unit_weight": [1.0], + }), + weights="tax_unit_weight" +) + +spm_unit_df = MicroDataFrame( + pd.DataFrame({ + "spm_unit_id": [0], + "spm_unit_weight": [1.0], + }), + weights="spm_unit_weight" +) + +family_df = MicroDataFrame( + pd.DataFrame({ + "family_id": [0], + "family_weight": [1.0], + }), + weights="family_weight" +) + +marital_unit_df = MicroDataFrame( + pd.DataFrame({ + "marital_unit_id": [0, 1, 2], + "marital_unit_weight": [1.0, 1.0, 1.0], + }), + weights="marital_unit_weight" +) + +household_df = MicroDataFrame( + pd.DataFrame({ + "household_id": [0], + "household_weight": [1.0], + "state_code": ["CA"], + }), + weights="household_weight" +) + +dataset = PolicyEngineUSDataset( + name="Married couple scenario", + description="Two adults, two children", + filepath="./married_couple.h5", + year=2024, + data=USYearData( + person=person_df, + tax_unit=tax_unit_df, + spm_unit=spm_unit_df, + family=family_df, + marital_unit=marital_unit_df, + household=household_df, + ) +) +``` + +### Running a simulation + +```python +from policyengine.core import Simulation +from policyengine.tax_benefit_models.us import us_latest + +simulation = Simulation( + dataset=dataset, + tax_benefit_model_version=us_latest, +) +simulation.run() + +# Check results +output = simulation.output_dataset.data +print(output.household[["household_net_income", "household_benefits", "household_tax"]]) +``` + +## Key parameters + +### Income tax 
+ +- `gov.irs.income.standard_deduction.joint`: Standard deduction (married filing jointly) +- `gov.irs.income.standard_deduction.single`: Standard deduction (single) +- `gov.irs.income.bracket.rates[0]`: 10% bracket rate +- `gov.irs.income.bracket.rates[1]`: 12% bracket rate +- `gov.irs.income.bracket.rates[2]`: 22% bracket rate +- `gov.irs.income.bracket.thresholds.joint[0]`: 10% bracket threshold (MFJ) +- `gov.irs.income.bracket.thresholds.single[0]`: 10% bracket threshold (single) + +### Payroll tax + +- `gov.ssa.payroll.rate.employee`: Employee OASDI rate (6.2%) +- `gov.medicare.payroll.rate`: Medicare rate (1.45%) +- `gov.ssa.payroll.cap`: OASDI wage base ($168,600 in 2024) + +### Child Tax Credit + +- `gov.irs.credits.ctc.amount.base`: Base CTC amount ($2,000 per child) +- `gov.irs.credits.ctc.refundable.amount.max`: Maximum refundable amount ($1,700) +- `gov.irs.credits.ctc.phase_out.threshold.joint`: Phase-out threshold (MFJ) +- `gov.irs.credits.ctc.phase_out.rate`: Phase-out rate + +### Earned Income Tax Credit + +- `gov.irs.credits.eitc.max[0]`: Maximum EITC (0 children) +- `gov.irs.credits.eitc.max[1]`: Maximum EITC (1 child) +- `gov.irs.credits.eitc.max[2]`: Maximum EITC (2 children) +- `gov.irs.credits.eitc.max[3]`: Maximum EITC (3+ children) +- `gov.irs.credits.eitc.phase_out.start[0]`: Phase-out start (0 children) +- `gov.irs.credits.eitc.phase_out.rate[0]`: Phase-out rate (0 children) + +### SNAP + +- `gov.usda.snap.normal_allotment.max[1]`: Maximum benefit (1 person) +- `gov.usda.snap.normal_allotment.max[2]`: Maximum benefit (2 people) +- `gov.usda.snap.income_limit.net`: Net income limit (100% FPL) +- `gov.usda.snap.income_deduction.earned.rate`: Earned income deduction rate (20%) + +## Common policy reforms + +### Increasing standard deduction + +```python +from policyengine.core import Policy, Parameter, ParameterValue +import datetime + +parameter = Parameter( + name="gov.irs.income.standard_deduction.single", + tax_benefit_model_version=us_latest, + description="Standard deduction (single)", + data_type=float, +) + +policy = Policy( + name="Increase standard deduction to $20,000", + description="Raises single standard deduction from $14,600 to $20,000", + parameter_values=[ + ParameterValue( + parameter=parameter, + start_date=datetime.date(2024, 1, 1), + end_date=datetime.date(2024, 12, 31), + value=20000, + ) + ], +) +``` + +### Expanding Child Tax Credit + +```python +parameter = Parameter( + name="gov.irs.credits.ctc.amount.base", + tax_benefit_model_version=us_latest, + description="Base CTC amount", + data_type=float, +) + +policy = Policy( + name="Increase CTC to $3,000", + description="Expands CTC from $2,000 to $3,000 per child", + parameter_values=[ + ParameterValue( + parameter=parameter, + start_date=datetime.date(2024, 1, 1), + end_date=datetime.date(2024, 12, 31), + value=3000, + ) + ], +) +``` + +### Making CTC fully refundable + +```python +parameter = Parameter( + name="gov.irs.credits.ctc.refundable.amount.max", + tax_benefit_model_version=us_latest, + description="Maximum refundable CTC", + data_type=float, +) + +policy = Policy( + name="Fully refundable CTC", + description="Makes entire $2,000 CTC refundable", + parameter_values=[ + ParameterValue( + parameter=parameter, + start_date=datetime.date(2024, 1, 1), + end_date=datetime.date(2024, 12, 31), + value=2000, # Match base amount + ) + ], +) +``` + +## State variations + +The US model includes state-level variations for: + +- **State income tax**: Different rates and structures by state +- 
**State EITC**: State supplements to federal EITC +- **Medicaid**: State-specific eligibility and benefits +- **TANF**: State-administered cash assistance + +### State codes + +Use two-letter state codes (e.g., "CA", "NY", "TX"). All 50 states plus DC are supported. + +## Entity mapping considerations + +The US model's complex entity structure requires careful attention to entity mapping: + +### Person → Household + +When mapping person-level variables (like `ssi`) to household level, values are summed across all household members: + +```python +agg = Aggregate( + simulation=simulation, + variable="ssi", # Person-level + entity="household", # Aggregate to household + aggregate_type=AggregateType.SUM, +) +# Result: Total SSI for all persons in each household +``` + +### Tax unit → Household + +Tax units nest within households. A household may contain multiple tax units (e.g., adult child filing separately): + +```python +agg = Aggregate( + simulation=simulation, + variable="income_tax", # Tax unit level + entity="household", # Aggregate to household + aggregate_type=AggregateType.SUM, +) +# Result: Total income tax for all tax units in each household +``` + +### Household → Person + +Household variables are replicated to all household members: + +```python +# household_net_income at person level +# Each person in household gets the same household_net_income value +``` + +## Data sources + +The US model can use several data sources: + +1. **Current Population Survey (CPS)**: Census Bureau household survey + - ~60,000 households + - Detailed income and demographic data + - Published annually + +2. **Enhanced CPS**: Calibrated and enhanced version + - Uprated to population totals + - Imputed benefit receipt + - Multiple projection years + +3. **Custom datasets**: User-created scenarios + - Full control over household composition + - Exact income levels + - Specific tax filing scenarios + +## Validation + +When creating custom datasets, validate: + +1. **Entity relationships**: All persons link to valid tax_unit, spm_unit, household +2. **Join key naming**: Use `person_household_id`, `person_tax_unit_id`, etc. +3. **Weights**: Appropriate weights for each entity level +4. **State codes**: Valid two-letter codes +5. **Filing status**: Tax units should reflect actual filing patterns + +## Examples + +See working examples in the `examples/` directory: + +- `income_distribution_us.py`: Analyse benefit distribution by income decile +- `employment_income_variation_us.py`: Vary employment income, analyse phase-outs +- `speedtest_us_simulation.py`: Performance benchmarking + +## References + +- PolicyEngine US documentation: https://policyengine.github.io/policyengine-us/ +- IRS tax information: https://www.irs.gov/forms-pubs +- Benefits.gov: https://www.benefits.gov/ +- SPM methodology: https://www.census.gov/topics/income-poverty/supplemental-poverty-measure.html diff --git a/docs/dev.md b/docs/dev.md index 8b7ded2c..accfa48c 100644 --- a/docs/dev.md +++ b/docs/dev.md @@ -5,4 +5,4 @@ General principles for developing this package's codebase go here. 1. **STRONG** preference for simplicity. Let's make this package as simple as it possibly can be. 2. Remember the goal of this package: to make it easy to create, run, save and analyse PolicyEngine simulations. When considering further features, always ask: can we instead *make it super easy* for people to do this outside the package? 3. Be consistent about property names. 
`name` = human readable few words you could put as the noun in a sentence without fail. `id` = unique identifier, ideally a UUID. `description` = longer human readable text that describes the object. `created_at` and `updated_at` = timestamps for when the object was created and last updated. -4. Constraints can be good. We should set constraints where they help us simplify the codebase and usage, but not where they unnecessarily block useful functionality. For example: a `Model`, e.g. PolicyEngine UK, is restricted to being basically a set of variables, baseline parameters, and a `f: set of tables -> set of tables` function. \ No newline at end of file +4. Constraints can be good. We should set constraints where they help us simplify the codebase and usage, but not where they unnecessarily block useful functionality. \ No newline at end of file diff --git a/docs/myst.yml b/docs/myst.yml index ea5176cf..053152c6 100644 --- a/docs/myst.yml +++ b/docs/myst.yml @@ -11,7 +11,9 @@ project: toc: # Auto-generated by `myst init --write-toc` - file: index.md - - file: quickstart.ipynb + - file: core-concepts.md + - file: country-models-uk.md + - file: country-models-us.md - file: dev.md site: From b76a6ff114cbacea1e7dd170892a5d56574426af Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Mon, 17 Nov 2025 10:03:41 +0000 Subject: [PATCH 31/35] Format --- examples/employment_income_variation_uk.py | 22 ++++++-- examples/speedtest_us_simulation.py | 65 +++++++++++++++------- src/policyengine/outputs/aggregate.py | 13 ++++- src/policyengine/utils/plotting.py | 4 +- 4 files changed, 74 insertions(+), 30 deletions(-) diff --git a/examples/employment_income_variation_uk.py b/examples/employment_income_variation_uk.py index 7cd29a28..32212c1b 100644 --- a/examples/employment_income_variation_uk.py +++ b/examples/employment_income_variation_uk.py @@ -221,7 +221,6 @@ def extract_results_by_employment_income( pension_credit = [] income_support = [] - for hh_idx, emp_income in enumerate(employment_incomes): # Get HBAI household net income agg = Aggregate( @@ -267,7 +266,6 @@ def extract_results_by_employment_income( agg.run() benefit_list.append(agg.result) - # Get household tax agg = Aggregate( simulation=simulation, @@ -305,13 +303,19 @@ def visualise_results(results: dict) -> None: # Calculate net employment income (employment income minus tax) net_employment = [ emp - tax - for emp, tax in zip(results["employment_income_hh"], results["household_tax"]) + for emp, tax in zip( + results["employment_income_hh"], results["household_tax"] + ) ] # Stack benefits and income components using PolicyEngine colors components = [ ("Net employment income", net_employment, COLORS["primary"]), - ("Universal Credit", results["universal_credit"], COLORS["blue_secondary"]), + ( + "Universal Credit", + results["universal_credit"], + COLORS["blue_secondary"], + ), ("Working Tax Credit", results["working_tax_credit"], COLORS["info"]), ("Child Tax Credit", results["child_tax_credit"], COLORS["success"]), ("Child Benefit", results["child_benefit"], COLORS["warning"]), @@ -371,9 +375,15 @@ def main(): print("\nSample results:") for emp_inc in [0, 25000, 50000, 100000]: - idx = employment_incomes.index(emp_inc) if emp_inc in employment_incomes else -1 + idx = ( + employment_incomes.index(emp_inc) + if emp_inc in employment_incomes + else -1 + ) if idx >= 0: - print(f" Employment income £{emp_inc:,}: HBAI net income £{results['hbai_household_net_income'][idx]:,.0f}") + print( + f" Employment income £{emp_inc:,}: HBAI net income 
£{results['hbai_household_net_income'][idx]:,.0f}" + ) print("\nGenerating visualisation...") visualise_results(results) diff --git a/examples/speedtest_us_simulation.py b/examples/speedtest_us_simulation.py index 755a17f9..b7d60ad7 100644 --- a/examples/speedtest_us_simulation.py +++ b/examples/speedtest_us_simulation.py @@ -29,7 +29,9 @@ def create_subset_dataset( tax_unit_df = pd.DataFrame(original_dataset.data.tax_unit).copy() # Sample random households (use n as seed to get different samples for different sizes) - sampled_households = household_df.sample(n=n_households, random_state=n_households).copy() + sampled_households = household_df.sample( + n=n_households, random_state=n_households + ).copy() sampled_household_ids = set(sampled_households["household_id"]) # Determine column naming convention @@ -65,7 +67,9 @@ def create_subset_dataset( ].copy() # Get IDs of group entities that have members in sampled households - sampled_marital_unit_ids = set(sampled_person[marital_unit_id_col].unique()) + sampled_marital_unit_ids = set( + sampled_person[marital_unit_id_col].unique() + ) sampled_family_ids = set(sampled_person[family_id_col].unique()) sampled_spm_unit_ids = set(sampled_person[spm_unit_id_col].unique()) sampled_tax_unit_ids = set(sampled_person[tax_unit_id_col].unique()) @@ -94,7 +98,8 @@ def create_subset_dataset( for new_id, old_id in enumerate(sorted(sampled_marital_unit_ids)) } family_id_map = { - old_id: new_id for new_id, old_id in enumerate(sorted(sampled_family_ids)) + old_id: new_id + for new_id, old_id in enumerate(sorted(sampled_family_ids)) } spm_unit_id_map = { old_id: new_id @@ -110,19 +115,23 @@ def create_subset_dataset( } # Reindex all entity IDs in household table - sampled_households["household_id"] = sampled_households["household_id"].map( - household_id_map - ) + sampled_households["household_id"] = sampled_households[ + "household_id" + ].map(household_id_map) # Reindex all entity IDs in person table - sampled_person["person_id"] = sampled_person["person_id"].map(person_id_map) + sampled_person["person_id"] = sampled_person["person_id"].map( + person_id_map + ) sampled_person[household_id_col] = sampled_person[household_id_col].map( household_id_map ) - sampled_person[marital_unit_id_col] = sampled_person[marital_unit_id_col].map( - marital_unit_id_map + sampled_person[marital_unit_id_col] = sampled_person[ + marital_unit_id_col + ].map(marital_unit_id_map) + sampled_person[family_id_col] = sampled_person[family_id_col].map( + family_id_map ) - sampled_person[family_id_col] = sampled_person[family_id_col].map(family_id_map) sampled_person[spm_unit_id_col] = sampled_person[spm_unit_id_col].map( spm_unit_id_map ) @@ -134,7 +143,9 @@ def create_subset_dataset( sampled_marital_unit["marital_unit_id"] = sampled_marital_unit[ "marital_unit_id" ].map(marital_unit_id_map) - sampled_family["family_id"] = sampled_family["family_id"].map(family_id_map) + sampled_family["family_id"] = sampled_family["family_id"].map( + family_id_map + ) sampled_spm_unit["spm_unit_id"] = sampled_spm_unit["spm_unit_id"].map( spm_unit_id_map ) @@ -143,14 +154,18 @@ def create_subset_dataset( ) # Sort by ID to ensure proper ordering - sampled_households = sampled_households.sort_values("household_id").reset_index( + sampled_households = sampled_households.sort_values( + "household_id" + ).reset_index(drop=True) + sampled_person = sampled_person.sort_values("person_id").reset_index( drop=True ) - sampled_person = sampled_person.sort_values("person_id").reset_index(drop=True) 
sampled_marital_unit = sampled_marital_unit.sort_values( "marital_unit_id" ).reset_index(drop=True) - sampled_family = sampled_family.sort_values("family_id").reset_index(drop=True) + sampled_family = sampled_family.sort_values("family_id").reset_index( + drop=True + ) sampled_spm_unit = sampled_spm_unit.sort_values("spm_unit_id").reset_index( drop=True ) @@ -176,7 +191,9 @@ def create_subset_dataset( spm_unit=MicroDataFrame( sampled_spm_unit, weights="spm_unit_weight" ), - tax_unit=MicroDataFrame(sampled_tax_unit, weights="tax_unit_weight"), + tax_unit=MicroDataFrame( + sampled_tax_unit, weights="tax_unit_weight" + ), ), ) @@ -199,7 +216,9 @@ def speedtest_simulation(dataset: PolicyEngineUSDataset) -> float: def main(): print("Loading full enhanced CPS dataset...") - dataset_path = Path(__file__).parent / "data" / "enhanced_cps_2024_year_2024.h5" + dataset_path = ( + Path(__file__).parent / "data" / "enhanced_cps_2024_year_2024.h5" + ) if not dataset_path.exists(): raise FileNotFoundError( @@ -219,7 +238,15 @@ def main(): print(f"Full dataset: {total_households:,} households") # Test different subset sizes - test_sizes = [100, 500, 1000, 2500, 5000, 10000, 21532] # Last is full size + test_sizes = [ + 100, + 500, + 1000, + 2500, + 5000, + 10000, + 21532, + ] # Last is full size results = [] @@ -252,9 +279,7 @@ def main(): print("\n" + "=" * 60) print("SPEEDTEST RESULTS") print("=" * 60) - print( - f"{'Households':<12} {'People':<10} {'Duration':<12} {'HH/sec':<10}" - ) + print(f"{'Households':<12} {'People':<10} {'Duration':<12} {'HH/sec':<10}") print("-" * 60) for result in results: diff --git a/src/policyengine/outputs/aggregate.py b/src/policyengine/outputs/aggregate.py index 2619c9e4..b5f3c351 100644 --- a/src/policyengine/outputs/aggregate.py +++ b/src/policyengine/outputs/aggregate.py @@ -34,6 +34,7 @@ class Aggregate(Output): ) result: Any | None = None + def run(self): # Convert quantile specification to describes_quantiles format if self.quantile is not None: @@ -82,7 +83,9 @@ def run(self): if filter_var_obj.entity != target_entity: filter_mapped = ( self.simulation.output_dataset.data.map_to_entity( - filter_var_obj.entity, target_entity, columns=[self.filter_variable] + filter_var_obj.entity, + target_entity, + columns=[self.filter_variable], ) ) filter_series = filter_mapped[self.filter_variable] @@ -94,10 +97,14 @@ def run(self): threshold = filter_series.quantile(self.filter_variable_eq) series = series[filter_series <= threshold] if self.filter_variable_leq is not None: - threshold = filter_series.quantile(self.filter_variable_leq) + threshold = filter_series.quantile( + self.filter_variable_leq + ) series = series[filter_series <= threshold] if self.filter_variable_geq is not None: - threshold = filter_series.quantile(self.filter_variable_geq) + threshold = filter_series.quantile( + self.filter_variable_geq + ) series = series[filter_series >= threshold] else: if self.filter_variable_eq is not None: diff --git a/src/policyengine/utils/plotting.py b/src/policyengine/utils/plotting.py index 661ab19e..500d272b 100644 --- a/src/policyengine/utils/plotting.py +++ b/src/policyengine/utils/plotting.py @@ -20,7 +20,9 @@ } # Typography -FONT_FAMILY = "Inter, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif" +FONT_FAMILY = ( + "Inter, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif" +) FONT_SIZE_LABEL = 12 FONT_SIZE_DEFAULT = 14 FONT_SIZE_TITLE = 16 From 3099d8707647d27cb9e491e34ab84ca95ff0480f Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: 
Mon, 17 Nov 2025 10:09:36 +0000 Subject: [PATCH 32/35] Format --- examples/employment_income_variation_uk.py | 10 ++++--- examples/employment_income_variation_us.py | 9 +++--- examples/income_bands_uk.py | 4 ++- examples/income_distribution_us.py | 8 +++-- examples/policy_change_uk.py | 20 +++++++------ examples/speedtest_us_simulation.py | 6 ++-- src/policyengine/core/__init__.py | 26 +++++++++------- src/policyengine/core/dataset.py | 4 +-- src/policyengine/core/dataset_version.py | 2 +- src/policyengine/core/dynamic.py | 1 + src/policyengine/core/output.py | 7 +++-- src/policyengine/core/simulation.py | 6 ++-- src/policyengine/core/tax_benefit_model.py | 12 ++------ .../core/tax_benefit_model_version.py | 4 +-- src/policyengine/core/variable.py | 3 +- src/policyengine/outputs/aggregate.py | 3 +- src/policyengine/outputs/change_aggregate.py | 3 +- src/policyengine/outputs/decile_impact.py | 5 ++-- src/policyengine/tax_benefit_models/uk.py | 12 +++++++- .../tax_benefit_models/uk/__init__.py | 4 +-- .../tax_benefit_models/uk/analysis.py | 8 +++-- .../tax_benefit_models/uk/datasets.py | 8 +++-- .../tax_benefit_models/uk/model.py | 24 ++++++++------- .../tax_benefit_models/uk/outputs.py | 6 ++-- src/policyengine/tax_benefit_models/us.py | 11 ++++++- .../tax_benefit_models/us/__init__.py | 7 +++-- .../tax_benefit_models/us/analysis.py | 8 +++-- .../tax_benefit_models/us/datasets.py | 8 +++-- .../tax_benefit_models/us/model.py | 30 +++++++++---------- .../tax_benefit_models/us/outputs.py | 6 ++-- src/policyengine/utils/__init__.py | 5 ++-- src/policyengine/utils/dates.py | 2 +- src/policyengine/utils/parametric_reforms.py | 6 ++-- src/policyengine/utils/plotting.py | 13 ++++---- tests/test_aggregate.py | 10 ++++--- tests/test_change_aggregate.py | 18 ++++++----- tests/test_entity_mapping.py | 1 + tests/test_uk_dataset.py | 8 +++-- tests/test_us_datasets.py | 9 +++--- tests/test_us_entity_mapping.py | 1 + tests/test_us_simulation.py | 6 ++-- 41 files changed, 203 insertions(+), 141 deletions(-) diff --git a/examples/employment_income_variation_uk.py b/examples/employment_income_variation_uk.py index 32212c1b..173c78ff 100644 --- a/examples/employment_income_variation_uk.py +++ b/examples/employment_income_variation_uk.py @@ -19,19 +19,21 @@ Run: python examples/employment_income_variation.py """ -import pandas as pd import tempfile from pathlib import Path + +import pandas as pd import plotly.graph_objects as go from microdf import MicroDataFrame + from policyengine.core import Simulation +from policyengine.outputs.aggregate import Aggregate, AggregateType from policyengine.tax_benefit_models.uk import ( PolicyEngineUKDataset, UKYearData, uk_latest, ) -from policyengine.outputs.aggregate import Aggregate, AggregateType -from policyengine.utils.plotting import format_fig, COLORS +from policyengine.utils.plotting import COLORS, format_fig def create_dataset_with_varied_employment_income( @@ -43,7 +45,7 @@ def create_dataset_with_varied_employment_income( Employment income varies across households. 
""" n_households = len(employment_incomes) - n_people = n_households * 3 # 1 adult + 2 children per household + n_households * 3 # 1 adult + 2 children per household # Create person data - one adult + 2 children per household person_ids = [] diff --git a/examples/employment_income_variation_us.py b/examples/employment_income_variation_us.py index 855c2e88..863d8018 100644 --- a/examples/employment_income_variation_us.py +++ b/examples/employment_income_variation_us.py @@ -10,19 +10,20 @@ Run: python examples/employment_income_variation_us.py """ -import pandas as pd import tempfile from pathlib import Path + +import pandas as pd import plotly.graph_objects as go from microdf import MicroDataFrame + from policyengine.core import Simulation from policyengine.tax_benefit_models.us import ( PolicyEngineUSDataset, USYearData, us_latest, ) -from policyengine.outputs.aggregate import Aggregate, AggregateType -from policyengine.utils.plotting import format_fig, COLORS +from policyengine.utils.plotting import COLORS, format_fig def create_dataset_with_varied_employment_income( @@ -34,7 +35,7 @@ def create_dataset_with_varied_employment_income( Employment income varies across households. """ n_households = len(employment_incomes) - n_people = n_households * 3 # 1 adult + 2 children per household + n_households * 3 # 1 adult + 2 children per household # Create person data - one adult + 2 children per household person_ids = [] diff --git a/examples/income_bands_uk.py b/examples/income_bands_uk.py index 06f22598..f4f43c72 100644 --- a/examples/income_bands_uk.py +++ b/examples/income_bands_uk.py @@ -10,14 +10,16 @@ """ from pathlib import Path + import plotly.graph_objects as go from plotly.subplots import make_subplots + from policyengine.core import Simulation +from policyengine.outputs.aggregate import Aggregate, AggregateType from policyengine.tax_benefit_models.uk import ( PolicyEngineUKDataset, uk_latest, ) -from policyengine.outputs.aggregate import Aggregate, AggregateType def load_representative_data(year: int = 2026) -> PolicyEngineUKDataset: diff --git a/examples/income_distribution_us.py b/examples/income_distribution_us.py index fdbf48bf..67417e13 100644 --- a/examples/income_distribution_us.py +++ b/examples/income_distribution_us.py @@ -9,17 +9,19 @@ Run: python examples/income_distribution_us.py """ -from pathlib import Path import time +from pathlib import Path + import plotly.graph_objects as go from plotly.subplots import make_subplots + from policyengine.core import Simulation +from policyengine.outputs.aggregate import Aggregate, AggregateType from policyengine.tax_benefit_models.us import ( PolicyEngineUSDataset, us_latest, ) -from policyengine.outputs.aggregate import Aggregate, AggregateType -from policyengine.utils.plotting import format_fig, COLORS +from policyengine.utils.plotting import COLORS, format_fig def load_representative_data(year: int = 2024) -> PolicyEngineUSDataset: diff --git a/examples/policy_change_uk.py b/examples/policy_change_uk.py index 574f37cd..d708448b 100644 --- a/examples/policy_change_uk.py +++ b/examples/policy_change_uk.py @@ -11,19 +11,21 @@ Run: python examples/policy_change.py """ -from pathlib import Path import datetime +from pathlib import Path + import plotly.graph_objects as go from plotly.subplots import make_subplots -from policyengine.core import Simulation, Policy, Parameter, ParameterValue -from policyengine.tax_benefit_models.uk import ( - PolicyEngineUKDataset, - uk_latest, -) + +from policyengine.core import Parameter, 
ParameterValue, Policy, Simulation from policyengine.outputs.change_aggregate import ( ChangeAggregate, ChangeAggregateType, ) +from policyengine.tax_benefit_models.uk import ( + PolicyEngineUKDataset, + uk_latest, +) def load_representative_data(year: int = 2026) -> PolicyEngineUKDataset: @@ -263,16 +265,16 @@ def print_summary(overall: dict, decile: dict, reform_name: str) -> None: print("=" * 60) print(f"Policy change impact summary: {reform_name}") print("=" * 60) - print(f"\nOverall impact:") + print("\nOverall impact:") print(f" Winners: {overall['winners']:.2f}m households") print(f" Losers: {overall['losers']:.2f}m households") print(f" No change: {overall['no_change']:.2f}m households") - print(f"\nFinancial impact:") + print("\nFinancial impact:") print( f" Net income change: £{overall['total_change']:.2f}bn (negative = loss)" ) print(f" Tax revenue change: £{overall['tax_revenue_change']:.2f}bn") - print(f"\nImpact by income decile:") + print("\nImpact by income decile:") for i, label in enumerate(decile["labels"]): print( f" {label}: {decile['losers'][i]:.2f}m losers, avg change £{decile['avg_loss'][i]:.0f}" diff --git a/examples/speedtest_us_simulation.py b/examples/speedtest_us_simulation.py index b7d60ad7..e0b18fb0 100644 --- a/examples/speedtest_us_simulation.py +++ b/examples/speedtest_us_simulation.py @@ -4,16 +4,18 @@ by running simulations on random subsets of households. """ -from pathlib import Path import time +from pathlib import Path + import pandas as pd +from microdf import MicroDataFrame + from policyengine.core import Simulation from policyengine.tax_benefit_models.us import ( PolicyEngineUSDataset, USYearData, us_latest, ) -from microdf import MicroDataFrame def create_subset_dataset( diff --git a/src/policyengine/core/__init__.py b/src/policyengine/core/__init__.py index 5ecdcafb..b96e8edd 100644 --- a/src/policyengine/core/__init__.py +++ b/src/policyengine/core/__init__.py @@ -1,14 +1,18 @@ -from .variable import Variable -from .dataset import Dataset, map_to_entity -from .dynamic import Dynamic -from .tax_benefit_model import TaxBenefitModel -from .tax_benefit_model_version import TaxBenefitModelVersion -from .parameter import Parameter -from .parameter_value import ParameterValue -from .policy import Policy -from .simulation import Simulation -from .dataset_version import DatasetVersion -from .output import Output, OutputCollection +from .dataset import Dataset +from .dataset import map_to_entity as map_to_entity +from .dataset_version import DatasetVersion as DatasetVersion +from .dynamic import Dynamic as Dynamic +from .output import Output as Output +from .output import OutputCollection as OutputCollection +from .parameter import Parameter as Parameter +from .parameter_value import ParameterValue as ParameterValue +from .policy import Policy as Policy +from .simulation import Simulation as Simulation +from .tax_benefit_model import TaxBenefitModel as TaxBenefitModel +from .tax_benefit_model_version import ( + TaxBenefitModelVersion as TaxBenefitModelVersion, +) +from .variable import Variable as Variable # Rebuild models to resolve forward references Dataset.model_rebuild() diff --git a/src/policyengine/core/dataset.py b/src/policyengine/core/dataset.py index d73ee71d..a79c0b6d 100644 --- a/src/policyengine/core/dataset.py +++ b/src/policyengine/core/dataset.py @@ -1,11 +1,11 @@ from uuid import uuid4 -from pydantic import BaseModel, Field, ConfigDict import pandas as pd from microdf import MicroDataFrame +from pydantic import BaseModel, ConfigDict, 
Field -from .tax_benefit_model import TaxBenefitModel from .dataset_version import DatasetVersion +from .tax_benefit_model import TaxBenefitModel class Dataset(BaseModel): diff --git a/src/policyengine/core/dataset_version.py b/src/policyengine/core/dataset_version.py index 29a0150b..711cd7d7 100644 --- a/src/policyengine/core/dataset_version.py +++ b/src/policyengine/core/dataset_version.py @@ -1,7 +1,7 @@ +from typing import TYPE_CHECKING from uuid import uuid4 from pydantic import BaseModel, Field -from typing import TYPE_CHECKING from .tax_benefit_model import TaxBenefitModel diff --git a/src/policyengine/core/dynamic.py b/src/policyengine/core/dynamic.py index 1a88a5f6..9b312952 100644 --- a/src/policyengine/core/dynamic.py +++ b/src/policyengine/core/dynamic.py @@ -3,6 +3,7 @@ from uuid import uuid4 from pydantic import BaseModel, Field + from .parameter_value import ParameterValue diff --git a/src/policyengine/core/output.py b/src/policyengine/core/output.py index 874b694f..a4bf969a 100644 --- a/src/policyengine/core/output.py +++ b/src/policyengine/core/output.py @@ -1,6 +1,7 @@ -from pydantic import BaseModel, ConfigDict +from typing import TypeVar + import pandas as pd -from typing import Generic, TypeVar +from pydantic import BaseModel, ConfigDict T = TypeVar("T", bound="Output") @@ -16,7 +17,7 @@ def run(self): raise NotImplementedError("Subclasses must implement run()") -class OutputCollection(BaseModel, Generic[T]): +class OutputCollection[T: "Output"](BaseModel): """Container for a collection of outputs with their DataFrame representation.""" model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/src/policyengine/core/simulation.py b/src/policyengine/core/simulation.py index 66ece502..1e493b9a 100644 --- a/src/policyengine/core/simulation.py +++ b/src/policyengine/core/simulation.py @@ -1,14 +1,12 @@ from datetime import datetime -from typing import Any, Dict, List from uuid import uuid4 from pydantic import BaseModel, Field from .dataset import Dataset from .dynamic import Dynamic -from .tax_benefit_model import TaxBenefitModel -from .tax_benefit_model_version import TaxBenefitModelVersion from .policy import Policy +from .tax_benefit_model_version import TaxBenefitModelVersion class Simulation(BaseModel): @@ -23,7 +21,7 @@ class Simulation(BaseModel): tax_benefit_model_version: TaxBenefitModelVersion = None output_dataset: Dataset | None = None - variables: Dict[str, List[str]] | None = Field( + variables: dict[str, list[str]] | None = Field( default=None, description="Optional dictionary mapping entity names to lists of variable names to calculate. 
If None, uses model defaults.", ) diff --git a/src/policyengine/core/tax_benefit_model.py b/src/policyengine/core/tax_benefit_model.py index 255a21b1..02cb94ef 100644 --- a/src/policyengine/core/tax_benefit_model.py +++ b/src/policyengine/core/tax_benefit_model.py @@ -1,17 +1,9 @@ -from collections.abc import Callable -from datetime import datetime from typing import TYPE_CHECKING -from pydantic import BaseModel, Field +from pydantic import BaseModel if TYPE_CHECKING: - from .variable import Variable - from .parameter import Parameter - from .simulation import Simulation - from .dataset import Dataset - from .policy import Policy - from .dynamic import Dynamic - from .parameter_value import ParameterValue + pass class TaxBenefitModel(BaseModel): diff --git a/src/policyengine/core/tax_benefit_model_version.py b/src/policyengine/core/tax_benefit_model_version.py index 0b6da7f3..8555c6f6 100644 --- a/src/policyengine/core/tax_benefit_model_version.py +++ b/src/policyengine/core/tax_benefit_model_version.py @@ -1,16 +1,16 @@ from datetime import datetime +from typing import TYPE_CHECKING from uuid import uuid4 from pydantic import BaseModel, Field from .tax_benefit_model import TaxBenefitModel -from typing import TYPE_CHECKING if TYPE_CHECKING: - from .variable import Variable from .parameter import Parameter from .parameter_value import ParameterValue from .simulation import Simulation + from .variable import Variable class TaxBenefitModelVersion(BaseModel): diff --git a/src/policyengine/core/variable.py b/src/policyengine/core/variable.py index 2428a0a7..24375120 100644 --- a/src/policyengine/core/variable.py +++ b/src/policyengine/core/variable.py @@ -1,6 +1,7 @@ -from pydantic import BaseModel from typing import Any +from pydantic import BaseModel + from .tax_benefit_model_version import TaxBenefitModelVersion diff --git a/src/policyengine/outputs/aggregate.py b/src/policyengine/outputs/aggregate.py index b5f3c351..2d41259c 100644 --- a/src/policyengine/outputs/aggregate.py +++ b/src/policyengine/outputs/aggregate.py @@ -1,7 +1,8 @@ -from policyengine.core import * from enum import Enum from typing import Any +from policyengine.core import Output, Simulation + class AggregateType(str, Enum): SUM = "sum" diff --git a/src/policyengine/outputs/change_aggregate.py b/src/policyengine/outputs/change_aggregate.py index 359ead57..b5bfe2df 100644 --- a/src/policyengine/outputs/change_aggregate.py +++ b/src/policyengine/outputs/change_aggregate.py @@ -1,7 +1,8 @@ -from policyengine.core import * from enum import Enum from typing import Any +from policyengine.core import Output, Simulation + class ChangeAggregateType(str, Enum): COUNT = "count" diff --git a/src/policyengine/outputs/decile_impact.py b/src/policyengine/outputs/decile_impact.py index f2e7837f..8fcc8579 100644 --- a/src/policyengine/outputs/decile_impact.py +++ b/src/policyengine/outputs/decile_impact.py @@ -1,6 +1,7 @@ -from policyengine.core import Simulation, Output, OutputCollection -from pydantic import ConfigDict import pandas as pd +from pydantic import ConfigDict + +from policyengine.core import Output, OutputCollection, Simulation class DecileImpact(Output): diff --git a/src/policyengine/tax_benefit_models/uk.py b/src/policyengine/tax_benefit_models/uk.py index 0056d54f..a9fb102a 100644 --- a/src/policyengine/tax_benefit_models/uk.py +++ b/src/policyengine/tax_benefit_models/uk.py @@ -1,6 +1,16 @@ """PolicyEngine UK tax-benefit model - imports from uk/ module.""" -from .uk import * +from .uk import ( + PolicyEngineUK, + 
PolicyEngineUKDataset, + PolicyEngineUKLatest, + ProgrammeStatistics, + UKYearData, + create_datasets, + general_policy_reform_analysis, + uk_latest, + uk_model, +) __all__ = [ "UKYearData", diff --git a/src/policyengine/tax_benefit_models/uk/__init__.py b/src/policyengine/tax_benefit_models/uk/__init__.py index f77f3988..ade6e531 100644 --- a/src/policyengine/tax_benefit_models/uk/__init__.py +++ b/src/policyengine/tax_benefit_models/uk/__init__.py @@ -1,8 +1,8 @@ """PolicyEngine UK tax-benefit model.""" -from .datasets import UKYearData, PolicyEngineUKDataset, create_datasets -from .model import PolicyEngineUK, PolicyEngineUKLatest, uk_model, uk_latest from .analysis import general_policy_reform_analysis +from .datasets import PolicyEngineUKDataset, UKYearData, create_datasets +from .model import PolicyEngineUK, PolicyEngineUKLatest, uk_latest, uk_model from .outputs import ProgrammeStatistics __all__ = [ diff --git a/src/policyengine/tax_benefit_models/uk/analysis.py b/src/policyengine/tax_benefit_models/uk/analysis.py index 9573cc52..40805bf2 100644 --- a/src/policyengine/tax_benefit_models/uk/analysis.py +++ b/src/policyengine/tax_benefit_models/uk/analysis.py @@ -1,13 +1,15 @@ """General utility functions for UK policy reform analysis.""" -from policyengine.core import Simulation, OutputCollection +import pandas as pd +from pydantic import BaseModel + +from policyengine.core import OutputCollection, Simulation from policyengine.outputs.decile_impact import ( DecileImpact, calculate_decile_impacts, ) + from .outputs import ProgrammeStatistics -from pydantic import BaseModel -import pandas as pd class PolicyReformAnalysis(BaseModel): diff --git a/src/policyengine/tax_benefit_models/uk/datasets.py b/src/policyengine/tax_benefit_models/uk/datasets.py index 0b94e4ca..113d4b57 100644 --- a/src/policyengine/tax_benefit_models/uk/datasets.py +++ b/src/policyengine/tax_benefit_models/uk/datasets.py @@ -1,8 +1,10 @@ -from policyengine.core import Dataset, map_to_entity -from pydantic import BaseModel, ConfigDict +from pathlib import Path + import pandas as pd from microdf import MicroDataFrame -from pathlib import Path +from pydantic import BaseModel, ConfigDict + +from policyengine.core import Dataset, map_to_entity class UKYearData(BaseModel): diff --git a/src/policyengine/tax_benefit_models/uk/model.py b/src/policyengine/tax_benefit_models/uk/model.py index 5a91f5d7..18f1ef25 100644 --- a/src/policyengine/tax_benefit_models/uk/model.py +++ b/src/policyengine/tax_benefit_models/uk/model.py @@ -1,19 +1,22 @@ +import datetime +from importlib.metadata import version +from pathlib import Path +from typing import TYPE_CHECKING + +import pandas as pd +import requests +from microdf import MicroDataFrame + from policyengine.core import ( + Parameter, + ParameterValue, TaxBenefitModel, TaxBenefitModelVersion, Variable, - Parameter, - ParameterValue, ) -import datetime -import requests -from importlib.metadata import version from policyengine.utils import parse_safe_date -import pandas as pd -from microdf import MicroDataFrame -from pathlib import Path + from .datasets import PolicyEngineUKDataset, UKYearData -from typing import TYPE_CHECKING if TYPE_CHECKING: from policyengine.core.simulation import Simulation @@ -43,8 +46,8 @@ class PolicyEngineUKLatest(TaxBenefitModelVersion): def __init__(self, **kwargs: dict): super().__init__(**kwargs) - from policyengine_uk.system import system from policyengine_core.enums import Enum + from policyengine_uk.system import system self.id = 
f"{self.model.id}@{self.version}" @@ -110,6 +113,7 @@ def __init__(self, **kwargs: dict): def run(self, simulation: "Simulation") -> "Simulation": from policyengine_uk import Microsimulation from policyengine_uk.data import UKSingleYearDataset + from policyengine.utils.parametric_reforms import ( simulation_modifier_from_parameter_values, ) diff --git a/src/policyengine/tax_benefit_models/uk/outputs.py b/src/policyengine/tax_benefit_models/uk/outputs.py index 445a37c5..cc9ee82d 100644 --- a/src/policyengine/tax_benefit_models/uk/outputs.py +++ b/src/policyengine/tax_benefit_models/uk/outputs.py @@ -1,13 +1,15 @@ """UK-specific output templates.""" +from typing import TYPE_CHECKING + +from pydantic import ConfigDict + from policyengine.core import Output from policyengine.outputs.aggregate import Aggregate, AggregateType from policyengine.outputs.change_aggregate import ( ChangeAggregate, ChangeAggregateType, ) -from pydantic import ConfigDict -from typing import TYPE_CHECKING if TYPE_CHECKING: from policyengine.core.simulation import Simulation diff --git a/src/policyengine/tax_benefit_models/us.py b/src/policyengine/tax_benefit_models/us.py index 50c7b063..c915a3b5 100644 --- a/src/policyengine/tax_benefit_models/us.py +++ b/src/policyengine/tax_benefit_models/us.py @@ -3,7 +3,16 @@ from importlib.util import find_spec if find_spec("policyengine_us") is not None: - from .us import * + from .us import ( + PolicyEngineUS, + PolicyEngineUSDataset, + PolicyEngineUSLatest, + ProgramStatistics, + USYearData, + general_policy_reform_analysis, + us_latest, + us_model, + ) __all__ = [ "USYearData", diff --git a/src/policyengine/tax_benefit_models/us/__init__.py b/src/policyengine/tax_benefit_models/us/__init__.py index 8c273fa0..63361789 100644 --- a/src/policyengine/tax_benefit_models/us/__init__.py +++ b/src/policyengine/tax_benefit_models/us/__init__.py @@ -4,14 +4,15 @@ if find_spec("policyengine_us") is not None: from policyengine.core import Dataset - from .datasets import USYearData, PolicyEngineUSDataset, create_datasets + + from .analysis import general_policy_reform_analysis + from .datasets import PolicyEngineUSDataset, USYearData, create_datasets from .model import ( PolicyEngineUS, PolicyEngineUSLatest, - us_model, us_latest, + us_model, ) - from .analysis import general_policy_reform_analysis from .outputs import ProgramStatistics # Rebuild Pydantic models to resolve forward references diff --git a/src/policyengine/tax_benefit_models/us/analysis.py b/src/policyengine/tax_benefit_models/us/analysis.py index 905749f1..c3098d45 100644 --- a/src/policyengine/tax_benefit_models/us/analysis.py +++ b/src/policyengine/tax_benefit_models/us/analysis.py @@ -1,13 +1,15 @@ """General utility functions for US policy reform analysis.""" -from policyengine.core import Simulation, OutputCollection +import pandas as pd +from pydantic import BaseModel + +from policyengine.core import OutputCollection, Simulation from policyengine.outputs.decile_impact import ( DecileImpact, calculate_decile_impacts, ) + from .outputs import ProgramStatistics -from pydantic import BaseModel -import pandas as pd class PolicyReformAnalysis(BaseModel): diff --git a/src/policyengine/tax_benefit_models/us/datasets.py b/src/policyengine/tax_benefit_models/us/datasets.py index b98497ee..f7b34481 100644 --- a/src/policyengine/tax_benefit_models/us/datasets.py +++ b/src/policyengine/tax_benefit_models/us/datasets.py @@ -1,8 +1,10 @@ -from policyengine.core import Dataset, map_to_entity -from pydantic import BaseModel, 
ConfigDict +from pathlib import Path + import pandas as pd from microdf import MicroDataFrame -from pathlib import Path +from pydantic import BaseModel, ConfigDict + +from policyengine.core import Dataset, map_to_entity class USYearData(BaseModel): diff --git a/src/policyengine/tax_benefit_models/us/model.py b/src/policyengine/tax_benefit_models/us/model.py index bb1747f9..5e2068c5 100644 --- a/src/policyengine/tax_benefit_models/us/model.py +++ b/src/policyengine/tax_benefit_models/us/model.py @@ -1,19 +1,22 @@ +import datetime +from importlib.metadata import version +from pathlib import Path +from typing import TYPE_CHECKING + +import pandas as pd +import requests +from microdf import MicroDataFrame + from policyengine.core import ( + Parameter, + ParameterValue, TaxBenefitModel, TaxBenefitModelVersion, Variable, - Parameter, - ParameterValue, ) -import datetime -import requests -from importlib.metadata import version from policyengine.utils import parse_safe_date -import pandas as pd -from microdf import MicroDataFrame -from pathlib import Path + from .datasets import PolicyEngineUSDataset, USYearData -from typing import TYPE_CHECKING if TYPE_CHECKING: from policyengine.core.simulation import Simulation @@ -50,8 +53,8 @@ def __init__(self, **kwargs: dict): kwargs["created_at"] = datetime.datetime.fromisoformat(upload_time) super().__init__(**kwargs) - from policyengine_us.system import system from policyengine_core.enums import Enum + from policyengine_us.system import system self.id = f"{self.model.id}@{self.version}" @@ -115,13 +118,10 @@ def __init__(self, **kwargs: dict): def run(self, simulation: "Simulation") -> "Simulation": from policyengine_us import Microsimulation from policyengine_us.system import system - from policyengine_core.simulations.simulation_builder import ( - SimulationBuilder, - ) + from policyengine.utils.parametric_reforms import ( simulation_modifier_from_parameter_values, ) - import numpy as np assert isinstance(simulation.dataset, PolicyEngineUSDataset) @@ -322,10 +322,10 @@ def _build_simulation_from_dataset(self, microsim, dataset, system): dataset: The dataset containing entity data system: The tax-benefit system """ + import numpy as np from policyengine_core.simulations.simulation_builder import ( SimulationBuilder, ) - import numpy as np # Create builder and instantiate entities builder = SimulationBuilder() diff --git a/src/policyengine/tax_benefit_models/us/outputs.py b/src/policyengine/tax_benefit_models/us/outputs.py index fb54ed5f..38e20858 100644 --- a/src/policyengine/tax_benefit_models/us/outputs.py +++ b/src/policyengine/tax_benefit_models/us/outputs.py @@ -1,13 +1,15 @@ """US-specific output templates.""" +from typing import TYPE_CHECKING + +from pydantic import ConfigDict + from policyengine.core import Output from policyengine.outputs.aggregate import Aggregate, AggregateType from policyengine.outputs.change_aggregate import ( ChangeAggregate, ChangeAggregateType, ) -from pydantic import ConfigDict -from typing import TYPE_CHECKING if TYPE_CHECKING: from policyengine.core.simulation import Simulation diff --git a/src/policyengine/utils/__init__.py b/src/policyengine/utils/__init__.py index ac764329..e73de67e 100644 --- a/src/policyengine/utils/__init__.py +++ b/src/policyengine/utils/__init__.py @@ -1,2 +1,3 @@ -from .dates import parse_safe_date -from .plotting import format_fig, COLORS +from .dates import parse_safe_date as parse_safe_date +from .plotting import COLORS as COLORS +from .plotting import format_fig as format_fig diff --git 
a/src/policyengine/utils/dates.py b/src/policyengine/utils/dates.py index d2439456..e3c65fab 100644 --- a/src/policyengine/utils/dates.py +++ b/src/policyengine/utils/dates.py @@ -1,5 +1,5 @@ -from datetime import datetime import calendar +from datetime import datetime def parse_safe_date(date_string: str) -> datetime: diff --git a/src/policyengine/utils/parametric_reforms.py b/src/policyengine/utils/parametric_reforms.py index 88918ec8..7d7a869a 100644 --- a/src/policyengine/utils/parametric_reforms.py +++ b/src/policyengine/utils/parametric_reforms.py @@ -1,7 +1,9 @@ -from policyengine.core import ParameterValue -from typing import Callable +from collections.abc import Callable + from policyengine_core.periods import period +from policyengine.core import ParameterValue + def simulation_modifier_from_parameter_values( parameter_values: list[ParameterValue], diff --git a/src/policyengine/utils/plotting.py b/src/policyengine/utils/plotting.py index 500d272b..77aed94f 100644 --- a/src/policyengine/utils/plotting.py +++ b/src/policyengine/utils/plotting.py @@ -1,8 +1,7 @@ """Plotting utilities for PolicyEngine visualisations.""" -from typing import Optional -import plotly.graph_objects as go +import plotly.graph_objects as go # PolicyEngine brand colours COLORS = { @@ -30,12 +29,12 @@ def format_fig( fig: go.Figure, - title: Optional[str] = None, - xaxis_title: Optional[str] = None, - yaxis_title: Optional[str] = None, + title: str | None = None, + xaxis_title: str | None = None, + yaxis_title: str | None = None, show_legend: bool = True, - height: Optional[int] = None, - width: Optional[int] = None, + height: int | None = None, + width: int | None = None, ) -> go.Figure: """Apply PolicyEngine visual style to a plotly figure. diff --git a/tests/test_aggregate.py b/tests/test_aggregate.py index 57c1a0fe..5b4e8b27 100644 --- a/tests/test_aggregate.py +++ b/tests/test_aggregate.py @@ -1,14 +1,16 @@ -import pandas as pd -import tempfile import os +import tempfile + +import pandas as pd from microdf import MicroDataFrame -from policyengine.core import * + +from policyengine.core import Simulation +from policyengine.outputs.aggregate import Aggregate, AggregateType from policyengine.tax_benefit_models.uk import ( PolicyEngineUKDataset, UKYearData, uk_latest, ) -from policyengine.outputs.aggregate import Aggregate, AggregateType def test_aggregate_sum(): diff --git a/tests/test_change_aggregate.py b/tests/test_change_aggregate.py index 6006cbe1..ea900db6 100644 --- a/tests/test_change_aggregate.py +++ b/tests/test_change_aggregate.py @@ -1,17 +1,21 @@ -import pandas as pd -import tempfile import os +import tempfile + +import pandas as pd from microdf import MicroDataFrame -from policyengine.core import * -from policyengine.tax_benefit_models.uk import ( - PolicyEngineUKDataset, - UKYearData, - uk_latest, + +from policyengine.core import ( + Simulation, ) from policyengine.outputs.change_aggregate import ( ChangeAggregate, ChangeAggregateType, ) +from policyengine.tax_benefit_models.uk import ( + PolicyEngineUKDataset, + UKYearData, + uk_latest, +) def test_change_aggregate_count(): diff --git a/tests/test_entity_mapping.py b/tests/test_entity_mapping.py index 48c39e3d..77babd44 100644 --- a/tests/test_entity_mapping.py +++ b/tests/test_entity_mapping.py @@ -1,6 +1,7 @@ import pandas as pd import pytest from microdf import MicroDataFrame + from policyengine.tax_benefit_models.uk import UKYearData diff --git a/tests/test_uk_dataset.py b/tests/test_uk_dataset.py index 5e3f8e68..f8c04453 100644 --- 
a/tests/test_uk_dataset.py +++ b/tests/test_uk_dataset.py @@ -1,8 +1,10 @@ -import pandas as pd -import tempfile import os +import tempfile + +import pandas as pd from microdf import MicroDataFrame -from policyengine.core import * + +from policyengine.core import Dataset, TaxBenefitModel from policyengine.tax_benefit_models.uk import ( PolicyEngineUKDataset, UKYearData, diff --git a/tests/test_us_datasets.py b/tests/test_us_datasets.py index 6f84c507..08011610 100644 --- a/tests/test_us_datasets.py +++ b/tests/test_us_datasets.py @@ -1,12 +1,13 @@ """Tests for US dataset creation from HuggingFace paths.""" -import pytest -import pandas as pd -from pathlib import Path import shutil +from pathlib import Path + +import pandas as pd + from policyengine.tax_benefit_models.us import ( - create_datasets, PolicyEngineUSDataset, + create_datasets, ) diff --git a/tests/test_us_entity_mapping.py b/tests/test_us_entity_mapping.py index 59fe6913..65fb67fb 100644 --- a/tests/test_us_entity_mapping.py +++ b/tests/test_us_entity_mapping.py @@ -1,6 +1,7 @@ import pandas as pd import pytest from microdf import MicroDataFrame + from policyengine.tax_benefit_models.us import USYearData diff --git a/tests/test_us_simulation.py b/tests/test_us_simulation.py index b3df9a67..4de79691 100644 --- a/tests/test_us_simulation.py +++ b/tests/test_us_simulation.py @@ -1,7 +1,9 @@ -import pandas as pd -import tempfile import os +import tempfile + +import pandas as pd from microdf import MicroDataFrame + from policyengine.core import Simulation from policyengine.tax_benefit_models.us import ( PolicyEngineUSDataset, From 001b19e6455441140862726177d96b1dfb99b33d Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Mon, 17 Nov 2025 10:12:59 +0000 Subject: [PATCH 33/35] Remove unused deps --- pyproject.toml | 16 +-- uv.lock | 366 ++----------------------------------------------- 2 files changed, 10 insertions(+), 372 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 58df3c7d..4df10f82 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,20 +13,11 @@ authors = [ license = {file = "LICENSE"} requires-python = ">=3.13" dependencies = [ - "sqlalchemy>=2.0.0", - "sqlmodel>=0.0.21", - "alembic>=1.13.0", - "psycopg2-binary>=2.9.0", - "pymysql>=1.1.0", - "google-cloud-storage>=2.10.0", - "getpass4", "pydantic>=2.0.0", "pandas>=2.0.0", - "rich>=13.0.0", - "ipywidgets>=8.0.0", "microdf_python", - "tqdm>=4.67.1", - "blosc>=1.11.3", + "plotly>=5.0.0", + "requests>=2.31.0", ] [project.optional-dependencies] @@ -64,9 +55,6 @@ where = ["src"] [tool.setuptools.package-data] "policyengine" = ["**/*"] -[project.scripts] -pe-migrate = "policyengine.migrations.runner:main" - [tool.pytest.ini_options] addopts = "-v" testpaths = [ diff --git a/uv.lock b/uv.lock index c8d1e963..811c58e8 100644 --- a/uv.lock +++ b/uv.lock @@ -23,20 +23,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/32/34/d4e1c02d3bee589efb5dfa17f88ea08bdb3e3eac12bc475462aec52ed223/alabaster-0.7.16-py3-none-any.whl", hash = "sha256:b46733c07dce03ae4e150330b975c75737fa60f0a7c591b6c8bf4928a28e2c92", size = 13511, upload-time = "2024-01-10T00:56:08.388Z" }, ] -[[package]] -name = "alembic" -version = "1.16.5" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "mako" }, - { name = "sqlalchemy" }, - { name = "typing-extensions" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/9a/ca/4dc52902cf3491892d464f5265a81e9dff094692c8a049a3ed6a05fe7ee8/alembic-1.16.5.tar.gz", hash = 
"sha256:a88bb7f6e513bd4301ecf4c7f2206fe93f9913f9b48dac3b78babde2d6fe765e", size = 1969868, upload-time = "2025-08-27T18:02:05.668Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/39/4a/4c61d4c84cfd9befb6fa08a702535b27b21fff08c946bc2f6139decbf7f7/alembic-1.16.5-py3-none-any.whl", hash = "sha256:e845dfe090c5ffa7b92593ae6687c5cb1a101e91fa53868497dbd79847f9dbe3", size = 247355, upload-time = "2025-08-27T18:02:07.37Z" }, -] - [[package]] name = "annotated-types" version = "0.7.0" @@ -137,22 +123,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/09/71/54e999902aed72baf26bca0d50781b01838251a462612966e9fc4891eadd/black-25.1.0-py3-none-any.whl", hash = "sha256:95e8176dae143ba9097f351d174fdaf0ccd29efb414b362ae3fd72bf0f710717", size = 207646, upload-time = "2025-01-29T04:15:38.082Z" }, ] -[[package]] -name = "blosc" -version = "1.11.3" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/e3/ca/3ec5a5d05e10ad200d887c8cfb9492d9a02e05f9f8f726aa178123b1711b/blosc-1.11.3.tar.gz", hash = "sha256:89ed658eba7814a92e89c44d8c524148d55921595bc133bd1a90f8888a9e088e", size = 1439627, upload-time = "2025-05-17T11:50:03.713Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a5/23/6ee0e7270ad6299e73483dfad31b17f8acf66f7768094316a35ee0534f1d/blosc-1.11.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:9b474c70b9765587323dd1d7ff8e9fa9e9b35ccb3bee77e7658ce9faf2e05f7f", size = 2291576, upload-time = "2025-05-17T11:49:41.013Z" }, - { url = "https://files.pythonhosted.org/packages/51/8f/d8097dd6bf952d4bc1a31852f717d5a1157b32c1bea50dac723ed8e6bc8d/blosc-1.11.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:291d153864f53960861a48c2a5f6706adc2a84a2bdd9c3d1c5353d9c32748a03", size = 1801973, upload-time = "2025-05-17T11:49:42.259Z" }, - { url = "https://files.pythonhosted.org/packages/1e/cb/7fdf0756e6a38d6a28c5063bc8ba8a8c8b1a1ab6980d777c52ca7dd942b1/blosc-1.11.3-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ece67bb34741a147e4120cff3ee3784121709a112d16795716b8f4239aaddfa4", size = 2485043, upload-time = "2025-05-17T11:49:44.034Z" }, - { url = "https://files.pythonhosted.org/packages/7c/b8/d21a1305356312ca0fc6bd54ad6fb91e7434f0efef545972eb72f040c815/blosc-1.11.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e70216dbddb85b69a8d0f62a4a5c09b7a1fce9ca2f329793e799f8b6f9fa3ab0", size = 2619988, upload-time = "2025-05-17T11:49:45.346Z" }, - { url = "https://files.pythonhosted.org/packages/a0/79/9ed273c9493e02f0bc5deacd3854ecabd6c6ba5371ed04b6c7702fd16f77/blosc-1.11.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:220865ffcac638f8f0f4b51259d4e4f3236165e5b43fffd1e836cd7cd29b9367", size = 2678176, upload-time = "2025-05-17T11:49:47.12Z" }, - { url = "https://files.pythonhosted.org/packages/79/0e/c50458a1e038c0f0da70c3223d2a34ad702b86a79d0921f23a8ffaae035f/blosc-1.11.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:d57dde8c335378e8443757b69d0b29e90dfc53047d01311e952aecc815167dec", size = 2752740, upload-time = "2025-05-17T11:49:48.909Z" }, - { url = "https://files.pythonhosted.org/packages/f1/0e/3a5ed949e0e23eb576c08017bb39e8612607cf8f591d8149b0fb82469a03/blosc-1.11.3-cp313-cp313-win32.whl", hash = "sha256:d3d72046580a50177811916c78130d6ae7307420733de6e950cb567c896b1ca5", size = 1530991, upload-time = "2025-05-17T11:49:50.121Z" }, - { url = 
"https://files.pythonhosted.org/packages/06/d4/0c3cdaf34b3ef705fdab465ad8df4a3bce5bbdf2bca8f2515eae90ae28a0/blosc-1.11.3-cp313-cp313-win_amd64.whl", hash = "sha256:73721c1949f2b8d2f4168cababbfe6280511f0da9a971ba7ec9c56eab9603824", size = 1815688, upload-time = "2025-05-17T11:49:51.434Z" }, -] - [[package]] name = "blosc2" version = "3.7.2" @@ -199,24 +169,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cb/8c/2b30c12155ad8de0cf641d76a8b396a16d2c36bc6d50b621a62b7c4567c1/build-1.3.0-py3-none-any.whl", hash = "sha256:7145f0b5061ba90a1500d60bd1b13ca0a8a4cebdd0cc16ed8adf1c0e739f43b4", size = 23382, upload-time = "2025-08-01T21:27:07.844Z" }, ] -[[package]] -name = "cachetools" -version = "5.5.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/6c/81/3747dad6b14fa2cf53fcf10548cf5aea6913e96fab41a3c198676f8948a5/cachetools-5.5.2.tar.gz", hash = "sha256:1a661caa9175d26759571b2e19580f9d6393969e5dfca11fdb1f947a23e640d4", size = 28380, upload-time = "2025-02-20T21:01:19.524Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/72/76/20fa66124dbe6be5cafeb312ece67de6b61dd91a0247d1ea13db4ebb33c2/cachetools-5.5.2-py3-none-any.whl", hash = "sha256:d26a22bcc62eb95c3beabd9f1ee5e820d3d2704fe2967cbe350e20c8ffcd3f0a", size = 10080, upload-time = "2025-02-20T21:01:16.647Z" }, -] - -[[package]] -name = "caugetch" -version = "0.0.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a3/ec/519cb37e3e58e23a5b02a74049128f6e701ccd8892b0cebecf701fac6177/caugetch-0.0.1.tar.gz", hash = "sha256:6f6ddb3b928fa272071b02aabb3342941cd99992f27413ba8c189eb4dc3e33b0", size = 2071, upload-time = "2019-10-15T22:39:49.315Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/70/33/64fee4626ec943c2d0c4eee31c784dab8452dfe014916190730880d4ea62/caugetch-0.0.1-py3-none-any.whl", hash = "sha256:ee743dcbb513409cd24cfc42435418073683ba2f4bb7ee9f8440088a47d59277", size = 3439, upload-time = "2019-10-15T22:39:47.122Z" }, -] - [[package]] name = "certifi" version = "2025.8.3" @@ -291,15 +243,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/85/32/10bb5764d90a8eee674e9dc6f4db6a0ab47c8c4d0d83c27f7c39ac415a4d/click-8.2.1-py3-none-any.whl", hash = "sha256:61a3265b914e850b85317d0b3109c7f8cd35a670f963866005d6ef1d5175a12b", size = 102215, upload-time = "2025-05-20T23:19:47.796Z" }, ] -[[package]] -name = "clipboard" -version = "0.0.4" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pyperclip" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/8a/38/17f3885713d0f39994563029942b1d31c93d4e56d80da505abfbfb3a3bc4/clipboard-0.0.4.tar.gz", hash = "sha256:a72a78e9c9bf68da1c3f29ee022417d13ec9e3824b511559fd2b702b1dd5b817", size = 1713, upload-time = "2014-05-22T12:49:08.683Z" } - [[package]] name = "colorama" version = "0.4.6" @@ -423,120 +366,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3a/34/2b07b72bee02a63241d654f5d8af87a2de977c59638eec41ca356ab915cd/furo-2025.7.19-py3-none-any.whl", hash = "sha256:bdea869822dfd2b494ea84c0973937e35d1575af088b6721a29c7f7878adc9e3", size = 342175, upload-time = "2025-07-19T10:52:02.399Z" }, ] -[[package]] -name = "getpass4" -version = "0.0.14.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "caugetch" }, - { name = "clipboard" }, - { name = "colorama" }, - { name = "pyperclip" }, -] -sdist = { url = 
"https://files.pythonhosted.org/packages/a2/f9/312f84afc384f693d02eb4ff7306a7268577a8b808aa08f0124c9abba683/getpass4-0.0.14.1.tar.gz", hash = "sha256:80aa4e3a665f2eccc6cda3ee22125eeb5c6338e91c40c4fd010b3c94c7aa4d3a", size = 5078, upload-time = "2021-11-28T17:08:47.276Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/0f/d3/ea114aba31f76418b2162e811793cde2e822c9d9ea8ca98d67f9e1f1bde6/getpass4-0.0.14.1-py3-none-any.whl", hash = "sha256:6642c11fb99db1bec90b963e863ec71cdb0b8888000f5089c6377bfbf833f8a9", size = 8683, upload-time = "2021-11-28T17:08:45.468Z" }, -] - -[[package]] -name = "google-api-core" -version = "2.25.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "google-auth" }, - { name = "googleapis-common-protos" }, - { name = "proto-plus" }, - { name = "protobuf" }, - { name = "requests" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/dc/21/e9d043e88222317afdbdb567165fdbc3b0aad90064c7e0c9eb0ad9955ad8/google_api_core-2.25.1.tar.gz", hash = "sha256:d2aaa0b13c78c61cb3f4282c464c046e45fbd75755683c9c525e6e8f7ed0a5e8", size = 165443, upload-time = "2025-06-12T20:52:20.439Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/14/4b/ead00905132820b623732b175d66354e9d3e69fcf2a5dcdab780664e7896/google_api_core-2.25.1-py3-none-any.whl", hash = "sha256:8a2a56c1fef82987a524371f99f3bd0143702fecc670c72e600c1cda6bf8dbb7", size = 160807, upload-time = "2025-06-12T20:52:19.334Z" }, -] - -[[package]] -name = "google-auth" -version = "2.40.3" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "cachetools" }, - { name = "pyasn1-modules" }, - { name = "rsa" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/9e/9b/e92ef23b84fa10a64ce4831390b7a4c2e53c0132568d99d4ae61d04c8855/google_auth-2.40.3.tar.gz", hash = "sha256:500c3a29adedeb36ea9cf24b8d10858e152f2412e3ca37829b3fa18e33d63b77", size = 281029, upload-time = "2025-06-04T18:04:57.577Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/17/63/b19553b658a1692443c62bd07e5868adaa0ad746a0751ba62c59568cd45b/google_auth-2.40.3-py2.py3-none-any.whl", hash = "sha256:1370d4593e86213563547f97a92752fc658456fe4514c809544f330fed45a7ca", size = 216137, upload-time = "2025-06-04T18:04:55.573Z" }, -] - -[[package]] -name = "google-cloud-core" -version = "2.4.3" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "google-api-core" }, - { name = "google-auth" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/d6/b8/2b53838d2acd6ec6168fd284a990c76695e84c65deee79c9f3a4276f6b4f/google_cloud_core-2.4.3.tar.gz", hash = "sha256:1fab62d7102844b278fe6dead3af32408b1df3eb06f5c7e8634cbd40edc4da53", size = 35861, upload-time = "2025-03-10T21:05:38.948Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/40/86/bda7241a8da2d28a754aad2ba0f6776e35b67e37c36ae0c45d49370f1014/google_cloud_core-2.4.3-py2.py3-none-any.whl", hash = "sha256:5130f9f4c14b4fafdff75c79448f9495cfade0d8775facf1b09c3bf67e027f6e", size = 29348, upload-time = "2025-03-10T21:05:37.785Z" }, -] - -[[package]] -name = "google-cloud-storage" -version = "3.3.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "google-api-core" }, - { name = "google-auth" }, - { name = "google-cloud-core" }, - { name = "google-crc32c" }, - { name = "google-resumable-media" }, - { name = "requests" }, -] -sdist = { url = 
"https://files.pythonhosted.org/packages/ce/0d/6be1c7e10d1e186e22990fdc22e7ece79f7c622370793cfe88aa8c658316/google_cloud_storage-3.3.1.tar.gz", hash = "sha256:60f291b0881e5c72919b156d1ee276d1b69a2538fcdc35f4e87559ae11678f77", size = 17224623, upload-time = "2025-09-01T05:59:02.804Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/80/67/68eee082fc77e718fa483893ac2463fe0ae8f28ccab334cea9dc5aba99b0/google_cloud_storage-3.3.1-py3-none-any.whl", hash = "sha256:8cace9359b85f315f21868cf771143d6dbb47dcc5a3a9317c8207accc4d10fd3", size = 275070, upload-time = "2025-09-01T05:59:00.633Z" }, -] - -[[package]] -name = "google-crc32c" -version = "1.7.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/19/ae/87802e6d9f9d69adfaedfcfd599266bf386a54d0be058b532d04c794f76d/google_crc32c-1.7.1.tar.gz", hash = "sha256:2bff2305f98846f3e825dbeec9ee406f89da7962accdb29356e4eadc251bd472", size = 14495, upload-time = "2025-03-26T14:29:13.32Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/8b/72/b8d785e9184ba6297a8620c8a37cf6e39b81a8ca01bb0796d7cbb28b3386/google_crc32c-1.7.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:df8b38bdaf1629d62d51be8bdd04888f37c451564c2042d36e5812da9eff3c35", size = 30467, upload-time = "2025-03-26T14:36:06.909Z" }, - { url = "https://files.pythonhosted.org/packages/34/25/5f18076968212067c4e8ea95bf3b69669f9fc698476e5f5eb97d5b37999f/google_crc32c-1.7.1-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:e42e20a83a29aa2709a0cf271c7f8aefaa23b7ab52e53b322585297bb94d4638", size = 30309, upload-time = "2025-03-26T15:06:15.318Z" }, - { url = "https://files.pythonhosted.org/packages/92/83/9228fe65bf70e93e419f38bdf6c5ca5083fc6d32886ee79b450ceefd1dbd/google_crc32c-1.7.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:905a385140bf492ac300026717af339790921f411c0dfd9aa5a9e69a08ed32eb", size = 33133, upload-time = "2025-03-26T14:41:34.388Z" }, - { url = "https://files.pythonhosted.org/packages/c3/ca/1ea2fd13ff9f8955b85e7956872fdb7050c4ace8a2306a6d177edb9cf7fe/google_crc32c-1.7.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b211ddaf20f7ebeec5c333448582c224a7c90a9d98826fbab82c0ddc11348e6", size = 32773, upload-time = "2025-03-26T14:41:35.19Z" }, - { url = "https://files.pythonhosted.org/packages/89/32/a22a281806e3ef21b72db16f948cad22ec68e4bdd384139291e00ff82fe2/google_crc32c-1.7.1-cp313-cp313-win_amd64.whl", hash = "sha256:0f99eaa09a9a7e642a61e06742856eec8b19fc0037832e03f941fe7cf0c8e4db", size = 33475, upload-time = "2025-03-26T14:29:11.771Z" }, - { url = "https://files.pythonhosted.org/packages/b8/c5/002975aff514e57fc084ba155697a049b3f9b52225ec3bc0f542871dd524/google_crc32c-1.7.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32d1da0d74ec5634a05f53ef7df18fc646666a25efaaca9fc7dcfd4caf1d98c3", size = 33243, upload-time = "2025-03-26T14:41:35.975Z" }, - { url = "https://files.pythonhosted.org/packages/61/cb/c585282a03a0cea70fcaa1bf55d5d702d0f2351094d663ec3be1c6c67c52/google_crc32c-1.7.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e10554d4abc5238823112c2ad7e4560f96c7bf3820b202660373d769d9e6e4c9", size = 32870, upload-time = "2025-03-26T14:41:37.08Z" }, -] - -[[package]] -name = "google-resumable-media" -version = "2.7.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "google-crc32c" }, -] -sdist = { url = 
"https://files.pythonhosted.org/packages/58/5a/0efdc02665dca14e0837b62c8a1a93132c264bd02054a15abb2218afe0ae/google_resumable_media-2.7.2.tar.gz", hash = "sha256:5280aed4629f2b60b847b0d42f9857fd4935c11af266744df33d8074cae92fe0", size = 2163099, upload-time = "2024-08-07T22:20:38.555Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/82/35/b8d3baf8c46695858cb9d8835a53baa1eeb9906ddaf2f728a5f5b640fd1e/google_resumable_media-2.7.2-py2.py3-none-any.whl", hash = "sha256:3ce7551e9fe6d99e9a126101d2536612bb73486721951e9562fee0f90c6ababa", size = 81251, upload-time = "2024-08-07T22:20:36.409Z" }, -] - -[[package]] -name = "googleapis-common-protos" -version = "1.70.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "protobuf" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/39/24/33db22342cf4a2ea27c9955e6713140fedd51e8b141b5ce5260897020f1a/googleapis_common_protos-1.70.0.tar.gz", hash = "sha256:0e1b44e0ea153e6594f9f394fef15193a68aaaea2d843f83e2742717ca753257", size = 145903, upload-time = "2025-04-14T10:17:02.924Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/86/f1/62a193f0227cf15a920390abe675f386dec35f7ae3ffe6da582d3ade42c7/googleapis_common_protos-1.70.0-py3-none-any.whl", hash = "sha256:b8bfcca8c25a2bb253e0e0b0adaf8c00773e5e6af6fd92397576680b807e0fd8", size = 294530, upload-time = "2025-04-14T10:17:01.271Z" }, -] - [[package]] name = "greenlet" version = "3.2.4" @@ -551,6 +380,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ee/43/3cecdc0349359e1a527cbf2e3e28e5f8f06d3343aaf82ca13437a9aa290f/greenlet-3.2.4-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:23768528f2911bcd7e475210822ffb5254ed10d71f4028387e5a99b4c6699671", size = 610497, upload-time = "2025-08-07T13:18:31.636Z" }, { url = "https://files.pythonhosted.org/packages/b8/19/06b6cf5d604e2c382a6f31cafafd6f33d5dea706f4db7bdab184bad2b21d/greenlet-3.2.4-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:00fadb3fedccc447f517ee0d3fd8fe49eae949e1cd0f6a611818f4f6fb7dc83b", size = 1121662, upload-time = "2025-08-07T13:42:41.117Z" }, { url = "https://files.pythonhosted.org/packages/a2/15/0d5e4e1a66fab130d98168fe984c509249c833c1a3c16806b90f253ce7b9/greenlet-3.2.4-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:d25c5091190f2dc0eaa3f950252122edbbadbb682aa7b1ef2f8af0f8c0afefae", size = 1149210, upload-time = "2025-08-07T13:18:24.072Z" }, + { url = "https://files.pythonhosted.org/packages/1c/53/f9c440463b3057485b8594d7a638bed53ba531165ef0ca0e6c364b5cc807/greenlet-3.2.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6e343822feb58ac4d0a1211bd9399de2b3a04963ddeec21530fc426cc121f19b", size = 1564759, upload-time = "2025-11-04T12:42:19.395Z" }, + { url = "https://files.pythonhosted.org/packages/47/e4/3bb4240abdd0a8d23f4f88adec746a3099f0d86bfedb623f063b2e3b4df0/greenlet-3.2.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ca7f6f1f2649b89ce02f6f229d7c19f680a6238af656f61e0115b24857917929", size = 1634288, upload-time = "2025-11-04T12:42:21.174Z" }, { url = "https://files.pythonhosted.org/packages/0b/55/2321e43595e6801e105fcfdee02b34c0f996eb71e6ddffca6b10b7e1d771/greenlet-3.2.4-cp313-cp313-win_amd64.whl", hash = "sha256:554b03b6e73aaabec3745364d6239e9e012d64c68ccd0b8430c64ccc14939a8b", size = 299685, upload-time = "2025-08-07T13:24:38.824Z" }, { url = "https://files.pythonhosted.org/packages/22/5c/85273fd7cc388285632b0498dbbab97596e04b154933dfe0f3e68156c68c/greenlet-3.2.4-cp314-cp314-macosx_11_0_universal2.whl", 
hash = "sha256:49a30d5fda2507ae77be16479bdb62a660fa51b1eb4928b524975b3bde77b3c0", size = 273586, upload-time = "2025-08-07T13:16:08.004Z" }, { url = "https://files.pythonhosted.org/packages/d1/75/10aeeaa3da9332c2e761e4c50d4c3556c21113ee3f0afa2cf5769946f7a3/greenlet-3.2.4-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:299fd615cd8fc86267b47597123e3f43ad79c9d8a22bebdce535e53550763e2f", size = 686346, upload-time = "2025-08-07T13:42:59.944Z" }, @@ -558,6 +389,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/dc/8b/29aae55436521f1d6f8ff4e12fb676f3400de7fcf27fccd1d4d17fd8fecd/greenlet-3.2.4-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:b4a1870c51720687af7fa3e7cda6d08d801dae660f75a76f3845b642b4da6ee1", size = 694659, upload-time = "2025-08-07T13:53:17.759Z" }, { url = "https://files.pythonhosted.org/packages/92/2e/ea25914b1ebfde93b6fc4ff46d6864564fba59024e928bdc7de475affc25/greenlet-3.2.4-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:061dc4cf2c34852b052a8620d40f36324554bc192be474b9e9770e8c042fd735", size = 695355, upload-time = "2025-08-07T13:18:34.517Z" }, { url = "https://files.pythonhosted.org/packages/72/60/fc56c62046ec17f6b0d3060564562c64c862948c9d4bc8aa807cf5bd74f4/greenlet-3.2.4-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:44358b9bf66c8576a9f57a590d5f5d6e72fa4228b763d0e43fee6d3b06d3a337", size = 657512, upload-time = "2025-08-07T13:18:33.969Z" }, + { url = "https://files.pythonhosted.org/packages/23/6e/74407aed965a4ab6ddd93a7ded3180b730d281c77b765788419484cdfeef/greenlet-3.2.4-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:2917bdf657f5859fbf3386b12d68ede4cf1f04c90c3a6bc1f013dd68a22e2269", size = 1612508, upload-time = "2025-11-04T12:42:23.427Z" }, + { url = "https://files.pythonhosted.org/packages/0d/da/343cd760ab2f92bac1845ca07ee3faea9fe52bee65f7bcb19f16ad7de08b/greenlet-3.2.4-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:015d48959d4add5d6c9f6c5210ee3803a830dce46356e3bc326d6776bde54681", size = 1680760, upload-time = "2025-11-04T12:42:25.341Z" }, { url = "https://files.pythonhosted.org/packages/e3/a5/6ddab2b4c112be95601c13428db1d8b6608a8b6039816f2ba09c346c08fc/greenlet-3.2.4-cp314-cp314-win_amd64.whl", hash = "sha256:e37ab26028f12dbb0ff65f29a8d3d44a765c61e729647bf2ddfbbed621726f01", size = 303425, upload-time = "2025-08-07T13:32:27.59Z" }, ] @@ -694,22 +527,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/91/d0/274fbf7b0b12643cbbc001ce13e6a5b1607ac4929d1b11c72460152c9fc3/ipython-8.37.0-py3-none-any.whl", hash = "sha256:ed87326596b878932dbcb171e3e698845434d8c61b8d8cd474bf663041a9dcf2", size = 831864, upload-time = "2025-05-31T16:39:06.38Z" }, ] -[[package]] -name = "ipywidgets" -version = "8.1.7" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "comm" }, - { name = "ipython" }, - { name = "jupyterlab-widgets" }, - { name = "traitlets" }, - { name = "widgetsnbextension" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/3e/48/d3dbac45c2814cb73812f98dd6b38bbcc957a4e7bb31d6ea9c03bf94ed87/ipywidgets-8.1.7.tar.gz", hash = "sha256:15f1ac050b9ccbefd45dccfbb2ef6bed0029d8278682d569d71b8dd96bee0376", size = 116721, upload-time = "2025-05-05T12:42:03.489Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/58/6a/9166369a2f092bd286d24e6307de555d63616e8ddb373ebad2b5635ca4cd/ipywidgets-8.1.7-py3-none-any.whl", hash = 
"sha256:764f2602d25471c213919b8a1997df04bef869251db4ca8efba1b76b1bd9f7bb", size = 139806, upload-time = "2025-05-05T12:41:56.833Z" }, -] - [[package]] name = "itables" version = "2.5.2" @@ -862,15 +679,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2f/57/6bffd4b20b88da3800c5d691e0337761576ee688eb01299eae865689d2df/jupyter_core-5.8.1-py3-none-any.whl", hash = "sha256:c28d268fc90fb53f1338ded2eb410704c5449a358406e8a948b75706e24863d0", size = 28880, upload-time = "2025-05-27T07:38:15.137Z" }, ] -[[package]] -name = "jupyterlab-widgets" -version = "3.0.15" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b9/7d/160595ca88ee87ac6ba95d82177d29ec60aaa63821d3077babb22ce031a5/jupyterlab_widgets-3.0.15.tar.gz", hash = "sha256:2920888a0c2922351a9202817957a68c07d99673504d6cd37345299e971bb08b", size = 213149, upload-time = "2025-05-05T12:32:31.004Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/43/6a/ca128561b22b60bd5a0c4ea26649e68c8556b82bc70a0c396eebc977fe86/jupyterlab_widgets-3.0.15-py3-none-any.whl", hash = "sha256:d59023d7d7ef71400d51e6fee9a88867f6e65e10a4201605d2d7f3e8f012a31c", size = 216571, upload-time = "2025-05-05T12:32:29.534Z" }, -] - [[package]] name = "latexcodec" version = "3.0.1" @@ -892,18 +700,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/04/1e/b832de447dee8b582cac175871d2f6c3d5077cc56d5575cadba1fd1cccfa/linkify_it_py-2.0.3-py3-none-any.whl", hash = "sha256:6bcbc417b0ac14323382aef5c5192c0075bf8a9d6b41820a2b66371eac6b6d79", size = 19820, upload-time = "2024-02-04T14:48:02.496Z" }, ] -[[package]] -name = "mako" -version = "1.3.10" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "markupsafe" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/9e/38/bd5b78a920a64d708fe6bc8e0a2c075e1389d53bef8413725c63ba041535/mako-1.3.10.tar.gz", hash = "sha256:99579a6f39583fa7e5630a28c3c1f440e4e97a414b80372649c0ce338da2ea28", size = 392474, upload-time = "2025-04-10T12:44:31.16Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/87/fb/99f81ac72ae23375f22b7afdb7642aba97c00a713c217124420147681a2f/mako-1.3.10-py3-none-any.whl", hash = "sha256:baef24a52fc4fc514a0887ac600f9f1cff3d82c61d4d700a1fa84d597b88db59", size = 78509, upload-time = "2025-04-10T12:50:53.297Z" }, -] - [[package]] name = "markdown-it-py" version = "3.0.0" @@ -1287,20 +1083,11 @@ name = "policyengine" version = "3.0.0" source = { editable = "." 
} dependencies = [ - { name = "alembic" }, - { name = "blosc" }, - { name = "getpass4" }, - { name = "google-cloud-storage" }, - { name = "ipywidgets" }, { name = "microdf-python" }, { name = "pandas" }, - { name = "psycopg2-binary" }, + { name = "plotly" }, { name = "pydantic" }, - { name = "pymysql" }, - { name = "rich" }, - { name = "sqlalchemy" }, - { name = "sqlmodel" }, - { name = "tqdm" }, + { name = "requests" }, ] [package.optional-dependencies] @@ -1330,19 +1117,15 @@ us = [ [package.metadata] requires-dist = [ - { name = "alembic", specifier = ">=1.13.0" }, { name = "autodoc-pydantic", marker = "extra == 'dev'" }, { name = "black", marker = "extra == 'dev'" }, - { name = "blosc", specifier = ">=1.11.3" }, { name = "build", marker = "extra == 'dev'" }, { name = "furo", marker = "extra == 'dev'" }, - { name = "getpass4" }, - { name = "google-cloud-storage", specifier = ">=2.10.0" }, - { name = "ipywidgets", specifier = ">=8.0.0" }, { name = "itables", marker = "extra == 'dev'" }, { name = "jupyter-book", marker = "extra == 'dev'" }, { name = "microdf-python" }, { name = "pandas", specifier = ">=2.0.0" }, + { name = "plotly", specifier = ">=5.0.0" }, { name = "policyengine-core", marker = "extra == 'dev'", specifier = ">=3.10" }, { name = "policyengine-core", marker = "extra == 'uk'", specifier = ">=3.10" }, { name = "policyengine-core", marker = "extra == 'us'", specifier = ">=3.10" }, @@ -1350,16 +1133,11 @@ requires-dist = [ { name = "policyengine-uk", marker = "extra == 'uk'", specifier = ">=2.51.0" }, { name = "policyengine-us", marker = "extra == 'dev'", specifier = ">=1.213.1" }, { name = "policyengine-us", marker = "extra == 'us'", specifier = ">=1.213.1" }, - { name = "psycopg2-binary", specifier = ">=2.9.0" }, { name = "pydantic", specifier = ">=2.0.0" }, - { name = "pymysql", specifier = ">=1.1.0" }, { name = "pytest", marker = "extra == 'dev'" }, { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.26.0" }, - { name = "rich", specifier = ">=13.0.0" }, + { name = "requests", specifier = ">=2.31.0" }, { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.5.0" }, - { name = "sqlalchemy", specifier = ">=2.0.0" }, - { name = "sqlmodel", specifier = ">=0.0.21" }, - { name = "tqdm", specifier = ">=4.67.1" }, { name = "yaml-changelog", marker = "extra == 'dev'", specifier = ">=0.1.7" }, ] provides-extras = ["uk", "us", "dev"] @@ -1432,32 +1210,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/84/03/0d3ce49e2505ae70cf43bc5bb3033955d2fc9f932163e84dc0779cc47f48/prompt_toolkit-3.0.52-py3-none-any.whl", hash = "sha256:9aac639a3bbd33284347de5ad8d68ecc044b91a762dc39b7c21095fcd6a19955", size = 391431, upload-time = "2025-08-27T15:23:59.498Z" }, ] -[[package]] -name = "proto-plus" -version = "1.26.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "protobuf" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/f4/ac/87285f15f7cce6d4a008f33f1757fb5a13611ea8914eb58c3d0d26243468/proto_plus-1.26.1.tar.gz", hash = "sha256:21a515a4c4c0088a773899e23c7bbade3d18f9c66c73edd4c7ee3816bc96a012", size = 56142, upload-time = "2025-03-10T15:54:38.843Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/4e/6d/280c4c2ce28b1593a19ad5239c8b826871fc6ec275c21afc8e1820108039/proto_plus-1.26.1-py3-none-any.whl", hash = "sha256:13285478c2dcf2abb829db158e1047e2f1e8d63a077d94263c2b88b043c75a66", size = 50163, upload-time = "2025-03-10T15:54:37.335Z" }, -] - -[[package]] -name = "protobuf" -version = "6.32.0" -source = 
{ registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/c0/df/fb4a8eeea482eca989b51cffd274aac2ee24e825f0bf3cbce5281fa1567b/protobuf-6.32.0.tar.gz", hash = "sha256:a81439049127067fc49ec1d36e25c6ee1d1a2b7be930675f919258d03c04e7d2", size = 440614, upload-time = "2025-08-14T21:21:25.015Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/33/18/df8c87da2e47f4f1dcc5153a81cd6bca4e429803f4069a299e236e4dd510/protobuf-6.32.0-cp310-abi3-win32.whl", hash = "sha256:84f9e3c1ff6fb0308dbacb0950d8aa90694b0d0ee68e75719cb044b7078fe741", size = 424409, upload-time = "2025-08-14T21:21:12.366Z" }, - { url = "https://files.pythonhosted.org/packages/e1/59/0a820b7310f8139bd8d5a9388e6a38e1786d179d6f33998448609296c229/protobuf-6.32.0-cp310-abi3-win_amd64.whl", hash = "sha256:a8bdbb2f009cfc22a36d031f22a625a38b615b5e19e558a7b756b3279723e68e", size = 435735, upload-time = "2025-08-14T21:21:15.046Z" }, - { url = "https://files.pythonhosted.org/packages/cc/5b/0d421533c59c789e9c9894683efac582c06246bf24bb26b753b149bd88e4/protobuf-6.32.0-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:d52691e5bee6c860fff9a1c86ad26a13afbeb4b168cd4445c922b7e2cf85aaf0", size = 426449, upload-time = "2025-08-14T21:21:16.687Z" }, - { url = "https://files.pythonhosted.org/packages/ec/7b/607764ebe6c7a23dcee06e054fd1de3d5841b7648a90fd6def9a3bb58c5e/protobuf-6.32.0-cp39-abi3-manylinux2014_aarch64.whl", hash = "sha256:501fe6372fd1c8ea2a30b4d9be8f87955a64d6be9c88a973996cef5ef6f0abf1", size = 322869, upload-time = "2025-08-14T21:21:18.282Z" }, - { url = "https://files.pythonhosted.org/packages/40/01/2e730bd1c25392fc32e3268e02446f0d77cb51a2c3a8486b1798e34d5805/protobuf-6.32.0-cp39-abi3-manylinux2014_x86_64.whl", hash = "sha256:75a2aab2bd1aeb1f5dc7c5f33bcb11d82ea8c055c9becbb41c26a8c43fd7092c", size = 322009, upload-time = "2025-08-14T21:21:19.893Z" }, - { url = "https://files.pythonhosted.org/packages/9c/f2/80ffc4677aac1bc3519b26bc7f7f5de7fce0ee2f7e36e59e27d8beb32dd1/protobuf-6.32.0-py3-none-any.whl", hash = "sha256:ba377e5b67b908c8f3072a57b63e2c6a4cbd18aea4ed98d2584350dbf46f2783", size = 169287, upload-time = "2025-08-14T21:21:23.515Z" }, -] - [[package]] name = "psutil" version = "6.1.1" @@ -1473,25 +1225,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7b/d7/7831438e6c3ebbfa6e01a927127a6cb42ad3ab844247f3c5b96bea25d73d/psutil-6.1.1-cp37-abi3-win_amd64.whl", hash = "sha256:f35cfccb065fff93529d2afb4a2e89e363fe63ca1e4a5da22b603a85833c2649", size = 254444, upload-time = "2024-12-19T18:22:11.335Z" }, ] -[[package]] -name = "psycopg2-binary" -version = "2.9.10" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/cb/0e/bdc8274dc0585090b4e3432267d7be4dfbfd8971c0fa59167c711105a6bf/psycopg2-binary-2.9.10.tar.gz", hash = "sha256:4b3df0e6990aa98acda57d983942eff13d824135fe2250e6522edaa782a06de2", size = 385764, upload-time = "2024-10-16T11:24:58.126Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/3e/30/d41d3ba765609c0763505d565c4d12d8f3c79793f0d0f044ff5a28bf395b/psycopg2_binary-2.9.10-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:26540d4a9a4e2b096f1ff9cce51253d0504dca5a85872c7f7be23be5a53eb18d", size = 3044699, upload-time = "2024-10-16T11:21:42.841Z" }, - { url = "https://files.pythonhosted.org/packages/35/44/257ddadec7ef04536ba71af6bc6a75ec05c5343004a7ec93006bee66c0bc/psycopg2_binary-2.9.10-cp313-cp313-macosx_14_0_arm64.whl", hash = 
"sha256:e217ce4d37667df0bc1c397fdcd8de5e81018ef305aed9415c3b093faaeb10fb", size = 3275245, upload-time = "2024-10-16T11:21:51.989Z" }, - { url = "https://files.pythonhosted.org/packages/1b/11/48ea1cd11de67f9efd7262085588790a95d9dfcd9b8a687d46caf7305c1a/psycopg2_binary-2.9.10-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:245159e7ab20a71d989da00f280ca57da7641fa2cdcf71749c193cea540a74f7", size = 2851631, upload-time = "2024-10-16T11:21:57.584Z" }, - { url = "https://files.pythonhosted.org/packages/62/e0/62ce5ee650e6c86719d621a761fe4bc846ab9eff8c1f12b1ed5741bf1c9b/psycopg2_binary-2.9.10-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3c4ded1a24b20021ebe677b7b08ad10bf09aac197d6943bfe6fec70ac4e4690d", size = 3082140, upload-time = "2024-10-16T11:22:02.005Z" }, - { url = "https://files.pythonhosted.org/packages/27/ce/63f946c098611f7be234c0dd7cb1ad68b0b5744d34f68062bb3c5aa510c8/psycopg2_binary-2.9.10-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3abb691ff9e57d4a93355f60d4f4c1dd2d68326c968e7db17ea96df3c023ef73", size = 3264762, upload-time = "2024-10-16T11:22:06.412Z" }, - { url = "https://files.pythonhosted.org/packages/43/25/c603cd81402e69edf7daa59b1602bd41eb9859e2824b8c0855d748366ac9/psycopg2_binary-2.9.10-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8608c078134f0b3cbd9f89b34bd60a943b23fd33cc5f065e8d5f840061bd0673", size = 3020967, upload-time = "2024-10-16T11:22:11.583Z" }, - { url = "https://files.pythonhosted.org/packages/5f/d6/8708d8c6fca531057fa170cdde8df870e8b6a9b136e82b361c65e42b841e/psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:230eeae2d71594103cd5b93fd29d1ace6420d0b86f4778739cb1a5a32f607d1f", size = 2872326, upload-time = "2024-10-16T11:22:16.406Z" }, - { url = "https://files.pythonhosted.org/packages/ce/ac/5b1ea50fc08a9df82de7e1771537557f07c2632231bbab652c7e22597908/psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:bb89f0a835bcfc1d42ccd5f41f04870c1b936d8507c6df12b7737febc40f0909", size = 2822712, upload-time = "2024-10-16T11:22:21.366Z" }, - { url = "https://files.pythonhosted.org/packages/c4/fc/504d4503b2abc4570fac3ca56eb8fed5e437bf9c9ef13f36b6621db8ef00/psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:f0c2d907a1e102526dd2986df638343388b94c33860ff3bbe1384130828714b1", size = 2920155, upload-time = "2024-10-16T11:22:25.684Z" }, - { url = "https://files.pythonhosted.org/packages/b2/d1/323581e9273ad2c0dbd1902f3fb50c441da86e894b6e25a73c3fda32c57e/psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f8157bed2f51db683f31306aa497311b560f2265998122abe1dce6428bd86567", size = 2959356, upload-time = "2024-10-16T11:22:30.562Z" }, - { url = "https://files.pythonhosted.org/packages/08/50/d13ea0a054189ae1bc21af1d85b6f8bb9bbc5572991055d70ad9006fe2d6/psycopg2_binary-2.9.10-cp313-cp313-win_amd64.whl", hash = "sha256:27422aa5f11fbcd9b18da48373eb67081243662f9b46e6fd07c3eb46e4535142", size = 2569224, upload-time = "2025-01-04T20:09:19.234Z" }, -] - [[package]] name = "ptyprocess" version = "0.7.0" @@ -1519,27 +1252,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e0/a9/023730ba63db1e494a271cb018dcd361bd2c917ba7004c3e49d5daf795a2/py_cpuinfo-9.0.0-py3-none-any.whl", hash = "sha256:859625bc251f64e21f077d099d4162689c762b5d6a4c3c97553d56241c9674d5", size = 22335, upload-time = "2022-10-25T20:38:27.636Z" }, ] -[[package]] -name = "pyasn1" -version = "0.6.1" -source = { 
registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ba/e9/01f1a64245b89f039897cb0130016d79f77d52669aae6ee7b159a6c4c018/pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034", size = 145322, upload-time = "2024-09-10T22:41:42.55Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c8/f1/d6a797abb14f6283c0ddff96bbdd46937f64122b8c925cab503dd37f8214/pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629", size = 83135, upload-time = "2024-09-11T16:00:36.122Z" }, -] - -[[package]] -name = "pyasn1-modules" -version = "0.4.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pyasn1" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/e9/e6/78ebbb10a8c8e4b61a59249394a4a594c1a7af95593dc933a349c8d00964/pyasn1_modules-0.4.2.tar.gz", hash = "sha256:677091de870a80aae844b1ca6134f54652fa2c8c5a52aa396440ac3106e941e6", size = 307892, upload-time = "2025-03-28T02:41:22.17Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/47/8d/d529b5d697919ba8c11ad626e835d4039be708a35b0d22de83a269a6682c/pyasn1_modules-0.4.2-py3-none-any.whl", hash = "sha256:29253a9207ce32b64c3ac6600edc75368f98473906e8fd1043bd6b5b1de2c14a", size = 181259, upload-time = "2025-03-28T02:41:19.028Z" }, -] - [[package]] name = "pybtex" version = "0.25.1" @@ -1660,21 +1372,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, ] -[[package]] -name = "pymysql" -version = "1.1.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f5/ae/1fe3fcd9f959efa0ebe200b8de88b5a5ce3e767e38c7ac32fb179f16a388/pymysql-1.1.2.tar.gz", hash = "sha256:4961d3e165614ae65014e361811a724e2044ad3ea3739de9903ae7c21f539f03", size = 48258, upload-time = "2025-08-24T12:55:55.146Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/7c/4c/ad33b92b9864cbde84f259d5df035a6447f91891f5be77788e2a3892bce3/pymysql-1.1.2-py3-none-any.whl", hash = "sha256:e6b1d89711dd51f8f74b1631fe08f039e7d76cf67a42a323d3178f0f25762ed9", size = 45300, upload-time = "2025-08-24T12:55:53.394Z" }, -] - -[[package]] -name = "pyperclip" -version = "1.9.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/30/23/2f0a3efc4d6a32f3b63cdff36cd398d9701d26cda58e3ab97ac79fb5e60d/pyperclip-1.9.0.tar.gz", hash = "sha256:b7de0142ddc81bfc5c7507eea19da920b92252b548b96186caf94a5e2527d310", size = 20961, upload-time = "2024-06-18T20:38:48.401Z" } - [[package]] name = "pyproject-hooks" version = "1.2.0" @@ -1857,19 +1554,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" }, ] -[[package]] -name = "rich" -version = "14.1.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "markdown-it-py" }, - { name = "pygments" }, -] -sdist = { url = 
"https://files.pythonhosted.org/packages/fe/75/af448d8e52bf1d8fa6a9d089ca6c07ff4453d86c65c145d0a300bb073b9b/rich-14.1.0.tar.gz", hash = "sha256:e497a48b844b0320d45007cdebfeaeed8db2a4f4bcf49f15e455cfc4af11eaa8", size = 224441, upload-time = "2025-07-25T07:32:58.125Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e3/30/3c4d035596d3cf444529e0b2953ad0466f6049528a879d27534700580395/rich-14.1.0-py3-none-any.whl", hash = "sha256:536f5f1785986d6dbdea3c75205c473f970777b4a0d6c6dd1b696aa05a3fa04f", size = 243368, upload-time = "2025-07-25T07:32:56.73Z" }, -] - [[package]] name = "rpds-py" version = "0.27.1" @@ -1936,18 +1620,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/32/7d/97119da51cb1dd3f2f3c0805f155a3aa4a95fa44fe7d78ae15e69edf4f34/rpds_py-0.27.1-cp314-cp314t-win_amd64.whl", hash = "sha256:6567d2bb951e21232c2f660c24cf3470bb96de56cdcb3f071a83feeaff8a2772", size = 230097, upload-time = "2025-08-27T12:15:03.961Z" }, ] -[[package]] -name = "rsa" -version = "4.9.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pyasn1" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/da/8a/22b7beea3ee0d44b1916c0c1cb0ee3af23b700b6da9f04991899d0c555d4/rsa-4.9.1.tar.gz", hash = "sha256:e7bdbfdb5497da4c07dfd35530e1a902659db6ff241e39d9953cad06ebd0ae75", size = 29034, upload-time = "2025-04-16T09:51:18.218Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/64/8d/0133e4eb4beed9e425d9a98ed6e081a55d195481b7632472be1af08d2f6b/rsa-4.9.1-py3-none-any.whl", hash = "sha256:68635866661c6836b8d39430f97a996acbd61bfa49406748ea243539fe239762", size = 34696, upload-time = "2025-04-16T09:51:17.142Z" }, -] - [[package]] name = "ruff" version = "0.12.11" @@ -2263,19 +1935,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b8/d9/13bdde6521f322861fab67473cec4b1cc8999f3871953531cf61945fad92/sqlalchemy-2.0.43-py3-none-any.whl", hash = "sha256:1681c21dd2ccee222c2fe0bef671d1aef7c504087c9c4e800371cfcc8ac966fc", size = 1924759, upload-time = "2025-08-11T15:39:53.024Z" }, ] -[[package]] -name = "sqlmodel" -version = "0.0.24" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pydantic" }, - { name = "sqlalchemy" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/86/4b/c2ad0496f5bdc6073d9b4cef52be9c04f2b37a5773441cc6600b1857648b/sqlmodel-0.0.24.tar.gz", hash = "sha256:cc5c7613c1a5533c9c7867e1aab2fd489a76c9e8a061984da11b4e613c182423", size = 116780, upload-time = "2025-03-07T05:43:32.887Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/16/91/484cd2d05569892b7fef7f5ceab3bc89fb0f8a8c0cde1030d383dbc5449c/sqlmodel-0.0.24-py3-none-any.whl", hash = "sha256:6778852f09370908985b667d6a3ab92910d0d5ec88adcaf23dbc242715ff7193", size = 28622, upload-time = "2025-03-07T05:43:30.37Z" }, -] - [[package]] name = "stack-data" version = "0.6.3" @@ -2444,15 +2103,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0b/2c/87f3254fd8ffd29e4c02732eee68a83a1d3c346ae39bc6822dcbcb697f2b/wheel-0.45.1-py3-none-any.whl", hash = "sha256:708e7481cc80179af0e556bbf0cc00b8444c7321e2700b8d8580231d13017248", size = 72494, upload-time = "2024-11-23T00:18:21.207Z" }, ] -[[package]] -name = "widgetsnbextension" -version = "4.0.14" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/41/53/2e0253c5efd69c9656b1843892052a31c36d37ad42812b5da45c62191f7e/widgetsnbextension-4.0.14.tar.gz", hash = 
"sha256:a3629b04e3edb893212df862038c7232f62973373869db5084aed739b437b5af", size = 1097428, upload-time = "2025-04-10T13:01:25.628Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ca/51/5447876806d1088a0f8f71e16542bf350918128d0a69437df26047c8e46f/widgetsnbextension-4.0.14-py3-none-any.whl", hash = "sha256:4875a9eaf72fbf5079dc372a51a9f268fc38d46f767cbf85c43a36da5cb9b575", size = 2196503, upload-time = "2025-04-10T13:01:23.086Z" }, -] - [[package]] name = "yaml-changelog" version = "0.3.0" From 048efedb7d28997fe681964ccebb6108ffa7a614 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Mon, 17 Nov 2025 10:15:32 +0000 Subject: [PATCH 34/35] Suppress warning --- .../tax_benefit_models/us/datasets.py | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/src/policyengine/tax_benefit_models/us/datasets.py b/src/policyengine/tax_benefit_models/us/datasets.py index f7b34481..80933562 100644 --- a/src/policyengine/tax_benefit_models/us/datasets.py +++ b/src/policyengine/tax_benefit_models/us/datasets.py @@ -1,3 +1,4 @@ +import warnings from pathlib import Path import pandas as pd @@ -74,13 +75,19 @@ def save(self) -> None: filepath = Path(self.filepath) if not filepath.parent.exists(): filepath.parent.mkdir(parents=True, exist_ok=True) - with pd.HDFStore(filepath, mode="w") as store: - store["person"] = pd.DataFrame(self.data.person) - store["marital_unit"] = pd.DataFrame(self.data.marital_unit) - store["family"] = pd.DataFrame(self.data.family) - store["spm_unit"] = pd.DataFrame(self.data.spm_unit) - store["tax_unit"] = pd.DataFrame(self.data.tax_unit) - store["household"] = pd.DataFrame(self.data.household) + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=pd.errors.PerformanceWarning, + message=".*PyTables will pickle object types.*", + ) + with pd.HDFStore(filepath, mode="w") as store: + store["person"] = pd.DataFrame(self.data.person) + store["marital_unit"] = pd.DataFrame(self.data.marital_unit) + store["family"] = pd.DataFrame(self.data.family) + store["spm_unit"] = pd.DataFrame(self.data.spm_unit) + store["tax_unit"] = pd.DataFrame(self.data.tax_unit) + store["household"] = pd.DataFrame(self.data.household) def load(self) -> None: """Load dataset from HDF5 file into this instance.""" From 6a955800f724c51d939a719164b7e74c8c73d289 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Mon, 17 Nov 2025 10:25:22 +0000 Subject: [PATCH 35/35] Minor fix --- src/policyengine/tax_benefit_models/us/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/policyengine/tax_benefit_models/us/datasets.py b/src/policyengine/tax_benefit_models/us/datasets.py index 80933562..676e08e3 100644 --- a/src/policyengine/tax_benefit_models/us/datasets.py +++ b/src/policyengine/tax_benefit_models/us/datasets.py @@ -58,7 +58,7 @@ class PolicyEngineUSDataset(Dataset): data: USYearData | None = None - def model_post_init(self, __context): + def model_post_init(self, __context) -> None: """Called after Pydantic initialization.""" # Make sure we are synchronised between in-memory and storage, at least on initialisation if self.data is not None: