diff --git a/.claude/skills/database-migrations.md b/.claude/skills/database-migrations.md new file mode 100644 index 0000000..fedbef8 --- /dev/null +++ b/.claude/skills/database-migrations.md @@ -0,0 +1,301 @@ +# Database Migration Guidelines + +## Overview + +This project uses **Alembic** for database migrations with **SQLModel** models. Alembic is the industry-standard migration tool for SQLAlchemy/SQLModel projects. + +**CRITICAL**: SQL migrations are the single source of truth for database schema. All table creation and schema changes MUST go through Alembic migrations. + +## Architecture + +``` +┌─────────────────────────────────────────────────────────────┐ +│ SQLModel Models (src/policyengine_api/models/) │ +│ - Define Python classes │ +│ - Used for ORM queries │ +│ - NOT the source of truth for schema │ +└─────────────────────────────────────────────────────────────┘ + │ + │ alembic revision --autogenerate + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ Alembic Migrations (alembic/versions/) │ +│ - Create/alter tables │ +│ - Add indexes, constraints │ +│ - SOURCE OF TRUTH for schema │ +└─────────────────────────────────────────────────────────────┘ + │ + │ alembic upgrade head + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ PostgreSQL Database (Supabase) │ +│ - Actual schema │ +│ - Tracked by alembic_version table │ +└─────────────────────────────────────────────────────────────┘ +``` + +## Essential Rules + +### 1. NEVER use SQLModel.metadata.create_all() for schema creation + +The old pattern of using `SQLModel.metadata.create_all()` is deprecated. All tables are created via Alembic migrations. + +### 2. Every schema change requires a migration + +When you modify a SQLModel model (add column, change type, add index), you MUST: +1. Update the model in `src/policyengine_api/models/` +2. Generate a migration: `uv run alembic revision --autogenerate -m "Description"` +3. 
**Read and verify the generated migration** (see below) +4. Apply it: `uv run alembic upgrade head` + +### 3. ALWAYS verify auto-generated migrations before applying + +**This is critical for AI agents.** After running `alembic revision --autogenerate`, you MUST: + +1. **Read the generated migration file** in `alembic/versions/` +2. **Verify the `upgrade()` function** contains the expected changes: + - Correct table/column names + - Correct column types (e.g., `sa.String()`, `sa.Uuid()`, `sa.Integer()`) + - Proper foreign key references + - Appropriate nullable settings +3. **Verify the `downgrade()` function** properly reverses the changes +4. **Check for Alembic autogenerate limitations:** + - It may miss renamed columns (shows as drop + add instead) + - It may not detect some index changes + - It doesn't handle data migrations +5. **Edit the migration if needed** before applying + +Example verification: +```python +# Generated migration - verify this looks correct: +def upgrade() -> None: + op.add_column('users', sa.Column('phone', sa.String(), nullable=True)) + +def downgrade() -> None: + op.drop_column('users', 'phone') +``` + +**Never blindly apply a migration without reading it first.** + +### 4. Migrations must be self-contained + +Each migration should: +- Create tables it needs (never assume they exist from Python) +- Include both `upgrade()` and `downgrade()` functions +- Be idempotent where possible (use `IF NOT EXISTS` patterns) + +### 5. Never use conditional logic based on table existence + +Migrations should NOT check if tables exist. 
Instead: +- Ensure migrations run in the correct order (use `down_revision`) +- The initial migration creates all base tables +- Subsequent migrations build on that foundation + +## Common Commands + +```bash +# Apply all pending migrations +uv run alembic upgrade head + +# Generate migration from model changes +uv run alembic revision --autogenerate -m "Add users email index" + +# Create empty migration (for manual SQL) +uv run alembic revision -m "Add custom index" + +# Check current migration state +uv run alembic current + +# Show migration history +uv run alembic history + +# Downgrade one revision +uv run alembic downgrade -1 + +# Downgrade to specific revision +uv run alembic downgrade +``` + +## Local Development Workflow + +```bash +# 1. Start Supabase +supabase start + +# 2. Initialize database (runs migrations + applies RLS policies) +uv run python scripts/init.py + +# 3. Seed data +uv run python scripts/seed.py +``` + +### Reset database (DESTRUCTIVE) + +```bash +uv run python scripts/init.py --reset +``` + +## Adding a New Model + +1. Create the model in `src/policyengine_api/models/` + +```python +# src/policyengine_api/models/my_model.py +from sqlmodel import SQLModel, Field +from uuid import UUID, uuid4 + +class MyModel(SQLModel, table=True): + __tablename__ = "my_models" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + name: str +``` + +2. Export in `__init__.py`: + +```python +# src/policyengine_api/models/__init__.py +from .my_model import MyModel +``` + +3. Generate migration: + +```bash +uv run alembic revision --autogenerate -m "Add my_models table" +``` + +4. Review the generated migration in `alembic/versions/` + +5. Apply the migration: + +```bash +uv run alembic upgrade head +``` + +6. Update `scripts/init.py` to include the table in RLS policies if needed. + +## Adding an Index + +1. Generate a migration: + +```bash +uv run alembic revision -m "Add index on users.email" +``` + +2. 
Edit the migration: + +```python +def upgrade() -> None: + op.create_index("idx_users_email", "users", ["email"]) + +def downgrade() -> None: + op.drop_index("idx_users_email", "users") +``` + +3. Apply: + +```bash +uv run alembic upgrade head +``` + +## Production Considerations + +### Applying migrations to production + +1. Migrations are automatically applied when deploying +2. Always test migrations locally first +3. For data migrations, consider running during low-traffic periods + +### Transitioning production from old system to Alembic + +Production databases that were created before Alembic (using the old `SQLModel.metadata.create_all()` approach or raw Supabase migrations) need special handling. Running `alembic upgrade head` would fail because the tables already exist. + +**The solution: `alembic stamp`** + +The `alembic stamp` command marks a migration as "already applied" without actually running it. This tells Alembic "the database is already at this state, start tracking from here." + +**How it works:** + +1. `alembic stamp ` inserts a row into the `alembic_version` table with the specified revision ID +2. Alembic now thinks that migration (and all migrations before it) have been applied +3. Future migrations will run normally starting from that point + +**Step-by-step production transition:** + +```bash +# 1. Connect to production database +# (set SUPABASE_DB_URL or other connection env vars) + +# 2. Check if alembic_version table exists +# If not, Alembic will create it automatically + +# 3. Verify production schema matches the initial migration +# Compare tables/columns in production against alembic/versions/20260204_d6e30d3b834d_initial_schema.py + +# 4. Stamp the initial migration as applied +uv run alembic stamp d6e30d3b834d + +# 5. If production also has the indexes from the second migration, stamp that too +uv run alembic stamp a17ac554f4aa + +# 6. Verify the stamp worked +uv run alembic current +# Should show: a17ac554f4aa (head) + +# 7. 
From now on, new migrations will apply normally +uv run alembic upgrade head +``` + +**Handling partially applied migrations:** + +If production has some but not all changes from a migration: + +1. Manually apply the missing changes via SQL +2. Then stamp that migration as complete +3. Or: create a new migration that only adds the missing pieces + +**After stamping:** + +- All future schema changes go through Alembic migrations +- Developers generate migrations with `alembic revision --autogenerate` +- Deployments run `alembic upgrade head` to apply pending migrations +- The `alembic_version` table tracks what's been applied + +## File Structure + +``` +alembic/ +├── env.py # Alembic configuration (imports models, sets DB URL) +├── script.py.mako # Template for new migrations +├── versions/ # Migration files +│ ├── 20260204_d6e30d3b834d_initial_schema.py +│ └── 20260204_a17ac554f4aa_add_parameter_values_indexes.py +alembic.ini # Alembic settings + +supabase/ +├── migrations/ # Supabase-specific migrations (storage only) +│ ├── 20241119000000_storage_bucket.sql +│ └── 20241121000000_storage_policies.sql +└── migrations_archived/ # Old table migrations (now in Alembic) +``` + +## Troubleshooting + +### "Target database is not up to date" + +Run `alembic upgrade head` to apply pending migrations. + +### "Can't locate revision" + +The alembic_version table has a revision that doesn't exist in your migrations folder. This can happen if someone deleted a migration file. Fix by stamping to a known revision: + +```bash +alembic stamp head # If tables are current +alembic stamp d6e30d3b834d # If at initial schema +``` + +### "Table already exists" + +The migration is trying to create a table that already exists. Options: +1. If this is a fresh setup, drop and recreate: `uv run python scripts/init.py --reset` +2. 
If in production, stamp the migration as applied: `alembic stamp ` diff --git a/CLAUDE.md b/CLAUDE.md index 2df55fc..d6fb240 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -75,7 +75,21 @@ Use `gh` CLI for GitHub operations to ensure Actions run correctly. ## Database -`make init` resets tables and storage. `make seed` populates UK/US models with variables, parameters, and datasets. +This project uses **Alembic** for database migrations. See `.claude/skills/database-migrations.md` for detailed guidelines. + +**Key rules:** +- All schema changes go through Alembic migrations (never use `SQLModel.metadata.create_all()`) +- After modifying a model: `uv run alembic revision --autogenerate -m "Description"` +- Apply migrations: `uv run alembic upgrade head` + +**Local development:** +```bash +supabase start # Start local Supabase +uv run python scripts/init.py # Run migrations + apply RLS policies +uv run python scripts/seed.py # Seed data +``` + +`scripts/init.py --reset` drops and recreates everything (destructive). ## Modal sandbox + Claude Code CLI gotchas diff --git a/alembic.ini b/alembic.ini new file mode 100644 index 0000000..ed54635 --- /dev/null +++ b/alembic.ini @@ -0,0 +1,145 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts. +# this is typically a path given in POSIX (e.g. forward slashes) +# format, relative to the token %(here)s which refers to the location of this +# ini file +script_location = %(here)s/alembic + +# template used to generate migration file names +# Prepend with date for easier chronological ordering +file_template = %%(year)d%%(month).2d%%(day).2d_%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. for multiple paths, the path separator +# is defined by "path_separator" below. +prepend_sys_path = . + + +# timezone to use when rendering the date within the migration file +# as well as the filename. 
+# If specified, requires the tzdata library which can be installed by adding +# `alembic[tz]` to the pip requirements. +# string value is passed to ZoneInfo() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to /versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "path_separator" +# below. +# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions + +# path_separator; This indicates what character is used to split lists of file +# paths, including version_locations and prepend_sys_path within configparser +# files such as alembic.ini. +# The default rendered in new alembic.ini files is "os", which uses os.pathsep +# to provide os-dependent path splitting. +# +# Note that in order to support legacy alembic.ini files, this default does NOT +# take place if path_separator is not present in alembic.ini. If this +# option is omitted entirely, fallback logic is as follows: +# +# 1. Parsing of the version_locations option falls back to using the legacy +# "version_path_separator" key, which if absent then falls back to the legacy +# behavior of splitting on spaces and/or commas. +# 2. Parsing of the prepend_sys_path option falls back to the legacy +# behavior of splitting on spaces, commas, or colons. +# +# Valid values for path_separator are: +# +# path_separator = : +# path_separator = ; +# path_separator = space +# path_separator = newline +# +# Use os.pathsep. Default configuration used for new projects. 
+path_separator = os + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +# database URL - This is overridden by env.py which reads from application settings. +# The placeholder below is only used if env.py doesn't set it. +sqlalchemy.url = postgresql://placeholder:placeholder@localhost/placeholder + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module +# NOTE: ruff is in dev dependencies, so this hook only works when dev deps are installed +# hooks = ruff +# ruff.type = module +# ruff.module = ruff +# ruff.options = check --fix REVISION_SCRIPT_FILENAME + +# Alternatively, use the exec runner to execute a binary found on your PATH +# hooks = ruff +# ruff.type = exec +# ruff.executable = ruff +# ruff.options = check --fix REVISION_SCRIPT_FILENAME + +# Logging configuration. This is also consumed by the user-maintained +# env.py script only. 
+[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARNING +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARNING +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/alembic/README b/alembic/README new file mode 100644 index 0000000..98e4f9c --- /dev/null +++ b/alembic/README @@ -0,0 +1 @@ +Generic single-database configuration. \ No newline at end of file diff --git a/alembic/env.py b/alembic/env.py new file mode 100644 index 0000000..f930498 --- /dev/null +++ b/alembic/env.py @@ -0,0 +1,87 @@ +"""Alembic environment configuration for SQLModel migrations. + +This module configures Alembic to: +1. Use the database URL from application settings +2. Import all SQLModel models for autogenerate support +3. Run migrations in both offline and online modes +""" + +import sys +from logging.config import fileConfig +from pathlib import Path + +from sqlalchemy import engine_from_config, pool +from sqlmodel import SQLModel + +from alembic import context + +# Add src to path so we can import policyengine_api +sys.path.insert(0, str(Path(__file__).parent.parent / "src")) + +# Import all models to register them with SQLModel.metadata +# This is required for autogenerate to detect model changes +from policyengine_api import models # noqa: F401 +from policyengine_api.config.settings import settings + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Override sqlalchemy.url with the actual database URL from settings +config.set_main_option("sqlalchemy.url", settings.database_url) + +# Interpret the config file for Python logging. 
+if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# SQLModel metadata for autogenerate support +# This allows Alembic to detect changes in your SQLModel models +target_metadata = SQLModel.metadata + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + """ + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure(connection=connection, target_metadata=target_metadata) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/alembic/script.py.mako b/alembic/script.py.mako new file mode 100644 index 0000000..1101630 --- /dev/null +++ b/alembic/script.py.mako @@ -0,0 +1,28 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. 
+revision: str = ${repr(up_revision)} +down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + """Upgrade schema.""" + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + """Downgrade schema.""" + ${downgrades if downgrades else "pass"} diff --git a/alembic/versions/20260204_0001_initial_schema.py b/alembic/versions/20260204_0001_initial_schema.py new file mode 100644 index 0000000..273124a --- /dev/null +++ b/alembic/versions/20260204_0001_initial_schema.py @@ -0,0 +1,537 @@ +"""Initial schema (main branch state) + +Revision ID: 0001_initial +Revises: +Create Date: 2026-02-04 + +This migration creates all base tables for the PolicyEngine API as they +exist on the main branch, BEFORE the household CRUD changes. +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. 
+revision: str = "0001_initial" +down_revision: Union[str, Sequence[str], None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Create all tables as they exist on main branch.""" + # ======================================================================== + # TIER 1: Tables with no foreign key dependencies + # ======================================================================== + + # Tax benefit models (e.g., "uk", "us") + op.create_table( + "tax_benefit_models", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.Column("description", sa.String(), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + ) + + # Users + op.create_table( + "users", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("first_name", sa.String(), nullable=False), + sa.Column("last_name", sa.String(), nullable=False), + sa.Column("email", sa.String(), nullable=False), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("email"), + ) + op.create_index("ix_users_email", "users", ["email"]) + + # Policies (reform definitions) + op.create_table( + "policies", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.Column("description", sa.String(), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + ) + + # Dynamics (behavioral response definitions) + op.create_table( + "dynamics", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("name", 
sa.String(), nullable=False), + sa.Column("description", sa.String(), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + ) + + # ======================================================================== + # TIER 2: Tables depending on tier 1 + # ======================================================================== + + # Tax benefit model versions + op.create_table( + "tax_benefit_model_versions", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("model_id", sa.Uuid(), nullable=False), + sa.Column("version", sa.String(), nullable=False), + sa.Column("description", sa.String(), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["model_id"], ["tax_benefit_models.id"]), + ) + + # Datasets (h5 files in storage) + op.create_table( + "datasets", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.Column("description", sa.String(), nullable=True), + sa.Column("filepath", sa.String(), nullable=False), + sa.Column("year", sa.Integer(), nullable=False), + sa.Column("is_output_dataset", sa.Boolean(), nullable=False, default=False), + sa.Column("tax_benefit_model_id", sa.Uuid(), nullable=False), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["tax_benefit_model_id"], ["tax_benefit_models.id"]), + ) + + # ======================================================================== + # TIER 3: Tables depending on tier 2 + # 
======================================================================== + + # Parameters (tax-benefit system parameters) + op.create_table( + "parameters", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.Column("label", sa.String(), nullable=True), + sa.Column("description", sa.String(), nullable=True), + sa.Column("data_type", sa.String(), nullable=True), + sa.Column("unit", sa.String(), nullable=True), + sa.Column("tax_benefit_model_version_id", sa.Uuid(), nullable=False), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint( + ["tax_benefit_model_version_id"], ["tax_benefit_model_versions.id"] + ), + ) + + # Variables (tax-benefit system variables) + op.create_table( + "variables", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.Column("entity", sa.String(), nullable=False), + sa.Column("description", sa.String(), nullable=True), + sa.Column("data_type", sa.String(), nullable=True), + sa.Column("possible_values", sa.JSON(), nullable=True), + sa.Column("tax_benefit_model_version_id", sa.Uuid(), nullable=False), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint( + ["tax_benefit_model_version_id"], ["tax_benefit_model_versions.id"] + ), + ) + + # Dataset versions + op.create_table( + "dataset_versions", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.Column("description", sa.String(), nullable=False), + sa.Column("dataset_id", sa.Uuid(), nullable=False), + sa.Column("tax_benefit_model_id", sa.Uuid(), nullable=False), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + 
sa.ForeignKeyConstraint(["dataset_id"], ["datasets.id"]), + sa.ForeignKeyConstraint(["tax_benefit_model_id"], ["tax_benefit_models.id"]), + ) + + # ======================================================================== + # TIER 4: Tables depending on tier 3 + # ======================================================================== + + # Parameter values (policy/dynamic parameter modifications) + op.create_table( + "parameter_values", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("parameter_id", sa.Uuid(), nullable=False), + sa.Column("value_json", sa.JSON(), nullable=True), + sa.Column("start_date", sa.DateTime(timezone=True), nullable=False), + sa.Column("end_date", sa.DateTime(timezone=True), nullable=True), + sa.Column("policy_id", sa.Uuid(), nullable=True), + sa.Column("dynamic_id", sa.Uuid(), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["parameter_id"], ["parameters.id"]), + sa.ForeignKeyConstraint(["policy_id"], ["policies.id"]), + sa.ForeignKeyConstraint(["dynamic_id"], ["dynamics.id"]), + ) + + # Simulations (economy calculations) - NOTE: No household support yet + op.create_table( + "simulations", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("dataset_id", sa.Uuid(), nullable=False), # Required in main + sa.Column("policy_id", sa.Uuid(), nullable=True), + sa.Column("dynamic_id", sa.Uuid(), nullable=True), + sa.Column("tax_benefit_model_version_id", sa.Uuid(), nullable=False), + sa.Column("output_dataset_id", sa.Uuid(), nullable=True), + sa.Column("status", sa.String(), nullable=False, default="pending"), + sa.Column("error_message", sa.String(), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + 
sa.Column("started_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["dataset_id"], ["datasets.id"]), + sa.ForeignKeyConstraint(["policy_id"], ["policies.id"]), + sa.ForeignKeyConstraint(["dynamic_id"], ["dynamics.id"]), + sa.ForeignKeyConstraint( + ["tax_benefit_model_version_id"], ["tax_benefit_model_versions.id"] + ), + sa.ForeignKeyConstraint(["output_dataset_id"], ["datasets.id"]), + ) + + # Household jobs (async household calculations) - legacy approach + op.create_table( + "household_jobs", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("tax_benefit_model_name", sa.String(), nullable=False), + sa.Column("request_data", sa.JSON(), nullable=False), + sa.Column("policy_id", sa.Uuid(), nullable=True), + sa.Column("dynamic_id", sa.Uuid(), nullable=True), + sa.Column("status", sa.String(), nullable=False, default="pending"), + sa.Column("error_message", sa.String(), nullable=True), + sa.Column("result", sa.JSON(), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.Column("started_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["policy_id"], ["policies.id"]), + sa.ForeignKeyConstraint(["dynamic_id"], ["dynamics.id"]), + ) + + # ======================================================================== + # TIER 5: Tables depending on simulations + # ======================================================================== + + # Reports (analysis reports) - NOTE: No report_type yet + op.create_table( + "reports", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("label", sa.String(), nullable=False), + sa.Column("description", sa.String(), nullable=True), + sa.Column("user_id", sa.Uuid(), nullable=True), + 
sa.Column("markdown", sa.Text(), nullable=True), + sa.Column("parent_report_id", sa.Uuid(), nullable=True), + sa.Column("status", sa.String(), nullable=False, default="pending"), + sa.Column("error_message", sa.String(), nullable=True), + sa.Column("baseline_simulation_id", sa.Uuid(), nullable=True), + sa.Column("reform_simulation_id", sa.Uuid(), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["user_id"], ["users.id"]), + sa.ForeignKeyConstraint(["parent_report_id"], ["reports.id"]), + sa.ForeignKeyConstraint(["baseline_simulation_id"], ["simulations.id"]), + sa.ForeignKeyConstraint(["reform_simulation_id"], ["simulations.id"]), + ) + + # Aggregates (single-simulation aggregate outputs) + op.create_table( + "aggregates", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("simulation_id", sa.Uuid(), nullable=False), + sa.Column("user_id", sa.Uuid(), nullable=True), + sa.Column("report_id", sa.Uuid(), nullable=True), + sa.Column("variable", sa.String(), nullable=False), + sa.Column("aggregate_type", sa.String(), nullable=False), + sa.Column("entity", sa.String(), nullable=True), + sa.Column("filter_config", sa.JSON(), nullable=False, default={}), + sa.Column("status", sa.String(), nullable=False, default="pending"), + sa.Column("error_message", sa.String(), nullable=True), + sa.Column("result", sa.Float(), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["simulation_id"], ["simulations.id"]), + sa.ForeignKeyConstraint(["user_id"], ["users.id"]), + sa.ForeignKeyConstraint(["report_id"], ["reports.id"]), + ) + + # Change aggregates (baseline vs reform comparison) + op.create_table( + "change_aggregates", + sa.Column("id", sa.Uuid(), nullable=False), + 
sa.Column("baseline_simulation_id", sa.Uuid(), nullable=False), + sa.Column("reform_simulation_id", sa.Uuid(), nullable=False), + sa.Column("user_id", sa.Uuid(), nullable=True), + sa.Column("report_id", sa.Uuid(), nullable=True), + sa.Column("variable", sa.String(), nullable=False), + sa.Column("aggregate_type", sa.String(), nullable=False), + sa.Column("entity", sa.String(), nullable=True), + sa.Column("filter_config", sa.JSON(), nullable=False, default={}), + sa.Column("change_geq", sa.Float(), nullable=True), + sa.Column("change_leq", sa.Float(), nullable=True), + sa.Column("status", sa.String(), nullable=False, default="pending"), + sa.Column("error_message", sa.String(), nullable=True), + sa.Column("result", sa.Float(), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["baseline_simulation_id"], ["simulations.id"]), + sa.ForeignKeyConstraint(["reform_simulation_id"], ["simulations.id"]), + sa.ForeignKeyConstraint(["user_id"], ["users.id"]), + sa.ForeignKeyConstraint(["report_id"], ["reports.id"]), + ) + + # Decile impacts + op.create_table( + "decile_impacts", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("baseline_simulation_id", sa.Uuid(), nullable=False), + sa.Column("reform_simulation_id", sa.Uuid(), nullable=False), + sa.Column("report_id", sa.Uuid(), nullable=True), + sa.Column("income_variable", sa.String(), nullable=False), + sa.Column("entity", sa.String(), nullable=True), + sa.Column("decile", sa.Integer(), nullable=False), + sa.Column("quantiles", sa.Integer(), nullable=False, default=10), + sa.Column("baseline_mean", sa.Float(), nullable=True), + sa.Column("reform_mean", sa.Float(), nullable=True), + sa.Column("absolute_change", sa.Float(), nullable=True), + sa.Column("relative_change", sa.Float(), nullable=True), + sa.Column("count_better_off", sa.Float(), nullable=True), + 
sa.Column("count_worse_off", sa.Float(), nullable=True), + sa.Column("count_no_change", sa.Float(), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["baseline_simulation_id"], ["simulations.id"]), + sa.ForeignKeyConstraint(["reform_simulation_id"], ["simulations.id"]), + sa.ForeignKeyConstraint(["report_id"], ["reports.id"]), + ) + + # Program statistics + op.create_table( + "program_statistics", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("baseline_simulation_id", sa.Uuid(), nullable=False), + sa.Column("reform_simulation_id", sa.Uuid(), nullable=False), + sa.Column("report_id", sa.Uuid(), nullable=True), + sa.Column("program_name", sa.String(), nullable=False), + sa.Column("entity", sa.String(), nullable=False), + sa.Column("is_tax", sa.Boolean(), nullable=False, default=False), + sa.Column("baseline_total", sa.Float(), nullable=True), + sa.Column("reform_total", sa.Float(), nullable=True), + sa.Column("change", sa.Float(), nullable=True), + sa.Column("baseline_count", sa.Float(), nullable=True), + sa.Column("reform_count", sa.Float(), nullable=True), + sa.Column("winners", sa.Float(), nullable=True), + sa.Column("losers", sa.Float(), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["baseline_simulation_id"], ["simulations.id"]), + sa.ForeignKeyConstraint(["reform_simulation_id"], ["simulations.id"]), + sa.ForeignKeyConstraint(["report_id"], ["reports.id"]), + ) + + # Poverty + op.create_table( + "poverty", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("simulation_id", sa.Uuid(), nullable=False), + sa.Column("report_id", sa.Uuid(), nullable=True), + sa.Column("poverty_type", sa.String(), nullable=False), + sa.Column("entity", sa.String(), nullable=False, 
default="person"), + sa.Column("filter_variable", sa.String(), nullable=True), + sa.Column("headcount", sa.Float(), nullable=True), + sa.Column("total_population", sa.Float(), nullable=True), + sa.Column("rate", sa.Float(), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint( + ["simulation_id"], ["simulations.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint(["report_id"], ["reports.id"], ondelete="CASCADE"), + ) + op.create_index("idx_poverty_simulation_id", "poverty", ["simulation_id"]) + op.create_index("idx_poverty_report_id", "poverty", ["report_id"]) + + # Inequality + op.create_table( + "inequality", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("simulation_id", sa.Uuid(), nullable=False), + sa.Column("report_id", sa.Uuid(), nullable=True), + sa.Column("income_variable", sa.String(), nullable=False), + sa.Column("entity", sa.String(), nullable=False, default="household"), + sa.Column("gini", sa.Float(), nullable=True), + sa.Column("top_10_share", sa.Float(), nullable=True), + sa.Column("top_1_share", sa.Float(), nullable=True), + sa.Column("bottom_50_share", sa.Float(), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint( + ["simulation_id"], ["simulations.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint(["report_id"], ["reports.id"], ondelete="CASCADE"), + ) + op.create_index("idx_inequality_simulation_id", "inequality", ["simulation_id"]) + op.create_index("idx_inequality_report_id", "inequality", ["report_id"]) + + +def downgrade() -> None: + """Drop all tables in reverse order.""" + # Tier 5 + op.drop_index("idx_inequality_report_id", "inequality") + op.drop_index("idx_inequality_simulation_id", "inequality") + op.drop_table("inequality") + 
op.drop_index("idx_poverty_report_id", "poverty") + op.drop_index("idx_poverty_simulation_id", "poverty") + op.drop_table("poverty") + op.drop_table("program_statistics") + op.drop_table("decile_impacts") + op.drop_table("change_aggregates") + op.drop_table("aggregates") + op.drop_table("reports") + + # Tier 4 + op.drop_table("household_jobs") + op.drop_table("simulations") + op.drop_table("parameter_values") + + # Tier 3 + op.drop_table("dataset_versions") + op.drop_table("variables") + op.drop_table("parameters") + + # Tier 2 + op.drop_table("datasets") + op.drop_table("tax_benefit_model_versions") + + # Tier 1 + op.drop_table("dynamics") + op.drop_table("policies") + op.drop_index("ix_users_email", "users") + op.drop_table("users") + op.drop_table("tax_benefit_models") diff --git a/alembic/versions/20260204_0002_add_household_support.py b/alembic/versions/20260204_0002_add_household_support.py new file mode 100644 index 0000000..beb00a0 --- /dev/null +++ b/alembic/versions/20260204_0002_add_household_support.py @@ -0,0 +1,170 @@ +"""Add household CRUD and impact analysis support + +Revision ID: 0002_household +Revises: 0001_initial +Create Date: 2026-02-04 + +This migration adds support for: +- Storing household definitions (households table) +- User-household associations for saved households +- Household-based simulations (adds household_id to simulations) +- Household impact reports (adds report_type to reports) +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. 
+revision: str = "0002_household" +down_revision: Union[str, Sequence[str], None] = "0001_initial" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Add household support.""" + # ======================================================================== + # NEW TABLES + # ======================================================================== + + # Households (stored household definitions) + op.create_table( + "households", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("tax_benefit_model_name", sa.String(), nullable=False), + sa.Column("year", sa.Integer(), nullable=False), + sa.Column("label", sa.String(), nullable=True), + sa.Column("household_data", sa.JSON(), nullable=False), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + "idx_households_model_name", "households", ["tax_benefit_model_name"] + ) + op.create_index("idx_households_year", "households", ["year"]) + + # User-household associations (many-to-many for saved households) + op.create_table( + "user_household_associations", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("user_id", sa.Uuid(), nullable=False), + sa.Column("household_id", sa.Uuid(), nullable=False), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint(["household_id"], ["households.id"], ondelete="CASCADE"), + sa.UniqueConstraint("user_id", "household_id"), + ) + op.create_index( + "idx_user_household_user", "user_household_associations", ["user_id"] + ) + op.create_index( + "idx_user_household_household", 
"user_household_associations", ["household_id"] + ) + + # ======================================================================== + # MODIFY SIMULATIONS TABLE + # ======================================================================== + + # Add simulation_type column (economy vs household) + op.add_column( + "simulations", + sa.Column( + "simulation_type", + sa.String(), + nullable=False, + server_default="economy", + ), + ) + + # Add household_id column (for household simulations) + op.add_column( + "simulations", + sa.Column("household_id", sa.Uuid(), nullable=True), + ) + op.create_foreign_key( + "fk_simulations_household_id", + "simulations", + "households", + ["household_id"], + ["id"], + ) + + # Add household_result column (stores household calculation results) + op.add_column( + "simulations", + sa.Column("household_result", sa.JSON(), nullable=True), + ) + + # Make dataset_id nullable (household simulations don't need a dataset) + op.alter_column( + "simulations", + "dataset_id", + existing_type=sa.Uuid(), + nullable=True, + ) + + # ======================================================================== + # MODIFY REPORTS TABLE + # ======================================================================== + + # Add report_type column (economy_comparison, household_impact, etc.) 
+ op.add_column( + "reports", + sa.Column("report_type", sa.String(), nullable=True), + ) + + +def downgrade() -> None: + """Remove household support.""" + # ======================================================================== + # REVERT REPORTS TABLE + # ======================================================================== + op.drop_column("reports", "report_type") + + # ======================================================================== + # REVERT SIMULATIONS TABLE + # ======================================================================== + + # Make dataset_id required again + op.alter_column( + "simulations", + "dataset_id", + existing_type=sa.Uuid(), + nullable=False, + ) + + # Remove household columns + op.drop_column("simulations", "household_result") + op.drop_constraint("fk_simulations_household_id", "simulations", type_="foreignkey") + op.drop_column("simulations", "household_id") + op.drop_column("simulations", "simulation_type") + + # ======================================================================== + # DROP NEW TABLES + # ======================================================================== + op.drop_index("idx_user_household_household", "user_household_associations") + op.drop_index("idx_user_household_user", "user_household_associations") + op.drop_table("user_household_associations") + + op.drop_index("idx_households_year", "households") + op.drop_index("idx_households_model_name", "households") + op.drop_table("households") diff --git a/alembic/versions/20260204_0003_add_parameter_values_indexes.py b/alembic/versions/20260204_0003_add_parameter_values_indexes.py new file mode 100644 index 0000000..53518cf --- /dev/null +++ b/alembic/versions/20260204_0003_add_parameter_values_indexes.py @@ -0,0 +1,52 @@ +"""Add parameter_values indexes + +Revision ID: 0003_param_idx +Revises: 0002_household +Create Date: 2026-02-04 02:20:00.000000 + +This migration adds performance indexes to the parameter_values table +for optimizing common 
query patterns. +""" + +from typing import Sequence, Union + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "0003_param_idx" +down_revision: Union[str, Sequence[str], None] = "0002_household" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Add performance indexes to parameter_values.""" + # Composite index for the most common query pattern (filtering by both) + op.create_index( + "idx_parameter_values_parameter_policy", + "parameter_values", + ["parameter_id", "policy_id"], + ) + + # Single index on policy_id for filtering by policy alone + op.create_index( + "idx_parameter_values_policy", + "parameter_values", + ["policy_id"], + ) + + # Partial index for baseline values (policy_id IS NULL) + # This optimizes the common "get current law values" query + op.create_index( + "idx_parameter_values_baseline", + "parameter_values", + ["parameter_id"], + postgresql_where="policy_id IS NULL", + ) + + +def downgrade() -> None: + """Remove parameter_values indexes.""" + op.drop_index("idx_parameter_values_baseline", "parameter_values") + op.drop_index("idx_parameter_values_policy", "parameter_values") + op.drop_index("idx_parameter_values_parameter_policy", "parameter_values") diff --git a/pyproject.toml b/pyproject.toml index 27eb310..1fe9093 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ dependencies = [ "fastapi-mcp>=0.4.0", "modal>=0.68.0", "anthropic>=0.40.0", + "alembic>=1.13.0", ] [project.optional-dependencies] diff --git a/scripts/init.py b/scripts/init.py index cf7a04a..3aa925b 100644 --- a/scripts/init.py +++ b/scripts/init.py @@ -1,12 +1,19 @@ -"""Initialise Supabase: reset database, recreate tables, buckets, and permissions. +"""Initialise Supabase database with tables, buckets, and permissions. -This script performs a complete reset of the Supabase instance: -1. 
Drops and recreates the public schema (all tables) -2. Deletes and recreates the storage bucket -3. Creates all tables from SQLModel definitions -4. Applies RLS policies and storage permissions +This script can run in two modes: +1. Init mode (default): Creates tables via Alembic, applies RLS policies +2. Reset mode (--reset): Drops everything and recreates from scratch (DESTRUCTIVE) + +Usage: + uv run python scripts/init.py # Safe init (creates if not exists) + uv run python scripts/init.py --reset # Destructive reset (drops everything) + +For local development after `supabase start`, use init mode. +For production, use init mode to ensure tables and policies exist. +Reset mode should only be used when you need a completely fresh database. """ +import subprocess import sys from pathlib import Path @@ -14,16 +21,14 @@ from rich.console import Console from rich.panel import Panel -from sqlmodel import SQLModel, create_engine +from sqlmodel import create_engine -# Import all models to register them with SQLModel.metadata -from policyengine_api import models # noqa: F401 from policyengine_api.config.settings import settings from policyengine_api.services.storage import get_service_role_client console = Console() -MIGRATIONS_DIR = Path(__file__).parent.parent / "supabase" / "migrations" +PROJECT_ROOT = Path(__file__).parent.parent def reset_storage_bucket(): @@ -57,30 +62,61 @@ def reset_storage_bucket(): console.print(f"[yellow]⚠ Warning with storage bucket: {e}[/yellow]") +def ensure_storage_bucket(): + """Ensure storage bucket exists (non-destructive).""" + console.print("[bold blue]Ensuring storage bucket exists...") + + try: + supabase = get_service_role_client() + bucket_name = settings.storage_bucket + + # Try to get bucket info + try: + supabase.storage.get_bucket(bucket_name) + console.print(f"[green]✓[/green] Bucket '{bucket_name}' exists") + except Exception: + # Bucket doesn't exist, create it + supabase.storage.create_bucket(bucket_name, options={"public": 
True}) + console.print(f"[green]✓[/green] Created bucket '{bucket_name}'") + + except Exception as e: + console.print(f"[yellow]⚠ Warning with storage bucket: {e}[/yellow]") + + def reset_database(): - """Drop and recreate all tables.""" - console.print("[bold blue]Resetting database...") + """Drop and recreate the public schema (DESTRUCTIVE).""" + console.print("[bold red]Dropping database schema...") engine = create_engine(settings.database_url, echo=False) - # Drop and recreate public schema - console.print(" Dropping public schema...") with engine.begin() as conn: conn.exec_driver_sql("DROP SCHEMA public CASCADE") conn.exec_driver_sql("CREATE SCHEMA public") conn.exec_driver_sql("GRANT ALL ON SCHEMA public TO postgres") conn.exec_driver_sql("GRANT ALL ON SCHEMA public TO public") - # Create all tables from SQLModel - console.print(" Creating tables...") - SQLModel.metadata.create_all(engine) + console.print("[green]✓[/green] Schema dropped and recreated") + return engine - tables = list(SQLModel.metadata.tables.keys()) - console.print(f"[green]✓[/green] Created {len(tables)} tables:") - for table in sorted(tables): - console.print(f" {table}") - return engine +def run_alembic_migrations(): + """Run Alembic migrations to create/update tables.""" + console.print("[bold blue]Running Alembic migrations...") + + result = subprocess.run( + ["uv", "run", "alembic", "upgrade", "head"], + cwd=PROJECT_ROOT, + capture_output=True, + text=True, + ) + + if result.returncode != 0: + console.print(f"[red]✗ Alembic migration failed:[/red]") + console.print(result.stderr) + raise RuntimeError("Alembic migration failed") + + console.print("[green]✓[/green] Alembic migrations complete") + console.print(result.stdout) def apply_storage_policies(engine): @@ -158,6 +194,10 @@ def apply_rls_policies(engine): "parameter_values", "users", "household_jobs", + "households", + "user_household_associations", + "poverty", + "inequality", ] # Read-only tables (public can read, only service 
role can write) @@ -178,6 +218,7 @@ def apply_rls_policies(engine): "dynamics", "reports", "household_jobs", + "households", ] # Read-only results tables @@ -186,6 +227,8 @@ def apply_rls_policies(engine): "change_aggregates", "decile_impacts", "program_statistics", + "poverty", + "inequality", ] sql_parts = [] @@ -230,6 +273,13 @@ def apply_rls_policies(engine): FOR SELECT TO anon, authenticated USING (true); """) + # User-household associations need special handling + sql_parts.append(""" + DROP POLICY IF EXISTS "Users can manage own associations" ON user_household_associations; + CREATE POLICY "Users can manage own associations" ON user_household_associations + FOR ALL TO anon, authenticated USING (true) WITH CHECK (true); + """) + sql = "\n".join(sql_parts) conn = engine.raw_connection() @@ -246,30 +296,53 @@ def apply_rls_policies(engine): def main(): - """Run full Supabase initialisation.""" - console.print( - Panel.fit( - "[bold red]⚠ WARNING: This will DELETE ALL DATA[/bold red]\n" - "This script resets the entire Supabase instance.", - title="Supabase init", + """Run Supabase initialisation.""" + reset_mode = "--reset" in sys.argv + + if reset_mode: + console.print( + Panel.fit( + "[bold red]⚠ WARNING: This will DELETE ALL DATA[/bold red]\n" + "This script will reset the entire Supabase instance.", + title="Supabase RESET", + ) ) - ) - # Confirm unless running non-interactively - if sys.stdin.isatty(): - response = console.input("\nType 'yes' to continue: ") - if response.lower() != "yes": - console.print("[yellow]Aborted[/yellow]") - return + # Confirm unless running non-interactively + if sys.stdin.isatty(): + response = console.input("\nType 'yes' to continue: ") + if response.lower() != "yes": + console.print("[yellow]Aborted[/yellow]") + return + + console.print() + + # Reset storage bucket + reset_storage_bucket() + console.print() + + # Drop database schema + engine = reset_database() + console.print() + else: + console.print( + Panel.fit( + "[bold 
blue]Initialising Supabase[/bold blue]\n" + "This will create tables if they don't exist (safe/idempotent).\n" + "Use [cyan]--reset[/cyan] flag to drop and recreate everything.", + title="Supabase init", + ) + ) + console.print() - console.print() + # Ensure storage bucket exists + ensure_storage_bucket() + console.print() - # Reset storage bucket - reset_storage_bucket() - console.print() + engine = create_engine(settings.database_url, echo=False) - # Reset database and create tables - engine = reset_database() + # Run Alembic migrations to create/update tables + run_alembic_migrations() console.print() # Apply storage policies diff --git a/scripts/seed.py b/scripts/seed.py index f3fbfa8..4274528 100644 --- a/scripts/seed.py +++ b/scripts/seed.py @@ -363,7 +363,7 @@ def seed_model(model_version, session, lite: bool = False) -> TaxBenefitModelVer return db_version -def seed_datasets(session, lite: bool = False): +def seed_datasets(session, lite: bool = False, skip_uk_datasets: bool = False): """Seed datasets and upload to S3.""" with logfire.span("seed_datasets"): mode_str = " (lite mode - 2026 only)" if lite else "" @@ -383,60 +383,64 @@ def seed_datasets(session, lite: bool = False): ) return - # UK datasets - console.print(" Creating UK datasets...") data_folder = str(Path(__file__).parent.parent / "data") - uk_datasets = ensure_uk_datasets(data_folder=data_folder) - - # In lite mode, only upload FRS 2026 - if lite: - uk_datasets = { - k: v for k, v in uk_datasets.items() if v.year == 2026 and "frs" in k - } - console.print(f" Lite mode: filtered to {len(uk_datasets)} dataset(s)") + # UK datasets uk_created = 0 uk_skipped = 0 - with logfire.span("seed_uk_datasets", count=len(uk_datasets)): - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - console=console, - ) as progress: - task = progress.add_task("UK datasets", total=len(uk_datasets)) - for _, pe_dataset in uk_datasets.items(): - progress.update(task, 
description=f"UK: {pe_dataset.name}") - - # Check if dataset already exists - existing = session.exec( - select(Dataset).where(Dataset.name == pe_dataset.name) - ).first() - - if existing: - uk_skipped += 1 + if skip_uk_datasets: + console.print(" [yellow]Skipping UK datasets (--skip-uk-datasets)[/yellow]") + else: + console.print(" Creating UK datasets...") + uk_datasets = ensure_uk_datasets(data_folder=data_folder) + + # In lite mode, only upload FRS 2026 + if lite: + uk_datasets = { + k: v for k, v in uk_datasets.items() if v.year == 2026 and "frs" in k + } + console.print(f" Lite mode: filtered to {len(uk_datasets)} dataset(s)") + + with logfire.span("seed_uk_datasets", count=len(uk_datasets)): + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task("UK datasets", total=len(uk_datasets)) + for _, pe_dataset in uk_datasets.items(): + progress.update(task, description=f"UK: {pe_dataset.name}") + + # Check if dataset already exists + existing = session.exec( + select(Dataset).where(Dataset.name == pe_dataset.name) + ).first() + + if existing: + uk_skipped += 1 + progress.advance(task) + continue + + # Upload to S3 + object_name = upload_dataset_for_seeding(pe_dataset.filepath) + + # Create database record + db_dataset = Dataset( + name=pe_dataset.name, + description=pe_dataset.description, + filepath=object_name, + year=pe_dataset.year, + tax_benefit_model_id=uk_model.id, + ) + session.add(db_dataset) + session.commit() + uk_created += 1 progress.advance(task) - continue - - # Upload to S3 - object_name = upload_dataset_for_seeding(pe_dataset.filepath) - - # Create database record - db_dataset = Dataset( - name=pe_dataset.name, - description=pe_dataset.description, - filepath=object_name, - year=pe_dataset.year, - tax_benefit_model_id=uk_model.id, - ) - session.add(db_dataset) - session.commit() - uk_created += 1 - progress.advance(task) - console.print( - f" 
[green]✓[/green] UK: {uk_created} created, {uk_skipped} skipped" - ) + console.print( + f" [green]✓[/green] UK: {uk_created} created, {uk_skipped} skipped" + ) # US datasets console.print(" Creating US datasets...") @@ -622,6 +626,11 @@ def main(): action="store_true", help="Lite mode: skip US state parameters, only seed FRS 2026 and CPS 2026 datasets", ) + parser.add_argument( + "--skip-uk-datasets", + action="store_true", + help="Skip UK datasets (useful when HuggingFace token is not available)", + ) args = parser.parse_args() with logfire.span("database_seeding"): @@ -638,7 +647,7 @@ def main(): console.print(f"[green]✓[/green] US model seeded: {us_version.id}\n") # Seed datasets - seed_datasets(session, lite=args.lite) + seed_datasets(session, lite=args.lite, skip_uk_datasets=args.skip_uk_datasets) # Seed example policies seed_example_policies(session) diff --git a/scripts/seed_common.py b/scripts/seed_common.py new file mode 100644 index 0000000..f6d7ab6 --- /dev/null +++ b/scripts/seed_common.py @@ -0,0 +1,359 @@ +"""Shared utilities for seed scripts.""" + +import io +import json +import logging +import math +import sys +import warnings +from datetime import datetime, timezone +from pathlib import Path +from uuid import uuid4 + +import logfire +from rich.console import Console +from rich.progress import Progress, SpinnerColumn, TextColumn +from sqlmodel import Session, create_engine + +# Disable all SQLAlchemy and database logging BEFORE any imports +logging.basicConfig(level=logging.ERROR) +logging.getLogger("sqlalchemy").setLevel(logging.ERROR) +warnings.filterwarnings("ignore") + +# Add src to path +sys.path.insert(0, str(Path(__file__).parent.parent / "src")) + +from policyengine_api.config.settings import settings # noqa: E402 + +# Configure logfire +if settings.logfire_token: + logfire.configure( + token=settings.logfire_token, + environment=settings.logfire_environment, + console=False, + ) + +console = Console() + + +def get_session(): + """Get 
database session with logging disabled.""" + engine = create_engine(settings.database_url, echo=False) + return Session(engine) + + +def bulk_insert(session, table: str, columns: list[str], rows: list[dict]): + """Fast bulk insert using PostgreSQL COPY via StringIO.""" + if not rows: + return + + # Get raw psycopg2 connection + connection = session.connection() + raw_conn = connection.connection.dbapi_connection + cursor = raw_conn.cursor() + + # Build CSV-like data in memory + output = io.StringIO() + for row in rows: + values = [] + for col in columns: + val = row[col] + if val is None: + values.append("\\N") + elif isinstance(val, str): + # Escape special characters for COPY + val = ( + val.replace("\\", "\\\\").replace("\t", "\\t").replace("\n", "\\n") + ) + values.append(val) + else: + values.append(str(val)) + output.write("\t".join(values) + "\n") + + output.seek(0) + + # COPY is the fastest way to bulk load PostgreSQL + cursor.copy_from(output, table, columns=columns, null="\\N") + session.commit() + + +def seed_model(model_version, session, lite: bool = False): + """Seed a tax-benefit model with its variables and parameters. + + Returns the TaxBenefitModelVersion that was created or found. 
+ """ + from policyengine_api.models import ( + TaxBenefitModel, + TaxBenefitModelVersion, + ) + from sqlmodel import select + + with logfire.span( + "seed_model", + model=model_version.model.id, + version=model_version.version, + ): + # Create or get the model + console.print(f"[bold blue]Seeding {model_version.model.id}...") + + existing_model = session.exec( + select(TaxBenefitModel).where( + TaxBenefitModel.name == model_version.model.id + ) + ).first() + + if existing_model: + db_model = existing_model + console.print(f" Using existing model: {db_model.id}") + else: + db_model = TaxBenefitModel( + name=model_version.model.id, + description=model_version.model.description, + ) + session.add(db_model) + session.commit() + session.refresh(db_model) + console.print(f" Created model: {db_model.id}") + + # Create model version + existing_version = session.exec( + select(TaxBenefitModelVersion).where( + TaxBenefitModelVersion.model_id == db_model.id, + TaxBenefitModelVersion.version == model_version.version, + ) + ).first() + + if existing_version: + console.print( + f" Model version {model_version.version} already exists, skipping" + ) + return existing_version + + db_version = TaxBenefitModelVersion( + model_id=db_model.id, + version=model_version.version, + description=f"Version {model_version.version}", + ) + session.add(db_version) + session.commit() + session.refresh(db_version) + console.print(f" Created version: {db_version.version}") + + # Add variables + with logfire.span("add_variables", count=len(model_version.variables)): + var_rows = [] + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task( + f"Preparing {len(model_version.variables)} variables", + total=len(model_version.variables), + ) + for var in model_version.variables: + var_rows.append( + { + "id": uuid4(), + "name": var.name, + "entity": var.entity, + "description": var.description or "", + 
"data_type": var.data_type.__name__ + if hasattr(var.data_type, "__name__") + else str(var.data_type), + "possible_values": None, + "tax_benefit_model_version_id": db_version.id, + "created_at": datetime.now(timezone.utc), + } + ) + progress.advance(task) + + console.print(f" Inserting {len(var_rows)} variables...") + bulk_insert( + session, + "variables", + [ + "id", + "name", + "entity", + "description", + "data_type", + "possible_values", + "tax_benefit_model_version_id", + "created_at", + ], + var_rows, + ) + + console.print( + f" [green]✓[/green] Added {len(model_version.variables)} variables" + ) + + # Add parameters - deduplicate by name (keep first occurrence) + # + # WHY DEDUPLICATION IS NEEDED: + # The policyengine package can provide multiple parameter entries with the same + # name. This happens because parameters can have multiple bracket entries or + # state-specific variants that share the same base name. We keep only the first + # occurrence to avoid database unique constraint violations and reduce redundancy. + # + # NOTE: We do NOT filter by label. Parameters without labels (bracket params, + # breakdown params) are still valid and needed for policy analysis. 
+ # + # In lite mode, exclude US state parameters (gov.states.*) + seen_names = set() + parameters_to_add = [] + skipped_state_params = 0 + skipped_duplicate = 0 + + for p in model_version.parameters: + if p.name in seen_names: + skipped_duplicate += 1 + continue + # In lite mode, skip state-level parameters for faster seeding + if lite and p.name.startswith("gov.states."): + skipped_state_params += 1 + continue + parameters_to_add.append(p) + seen_names.add(p.name) + + console.print(f" Parameter filtering:") + console.print(f" - Total from source: {len(model_version.parameters)}") + console.print(f" - Skipped (duplicate name): {skipped_duplicate}") + if lite and skipped_state_params > 0: + console.print(f" - Skipped (state params, lite mode): {skipped_state_params}") + console.print(f" - To add: {len(parameters_to_add)}") + + with logfire.span("add_parameters", count=len(parameters_to_add)): + # Build list of parameter dicts for bulk insert + param_rows = [] + param_names = [] # Track (pe_id, name, generated_uuid) + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task( + f"Preparing {len(parameters_to_add)} parameters", + total=len(parameters_to_add), + ) + for param in parameters_to_add: + param_uuid = uuid4() + param_rows.append( + { + "id": param_uuid, + "name": param.name, + "label": param.label if hasattr(param, "label") else None, + "description": param.description or "", + "data_type": param.data_type.__name__ + if hasattr(param.data_type, "__name__") + else str(param.data_type), + "unit": param.unit, + "tax_benefit_model_version_id": db_version.id, + "created_at": datetime.now(timezone.utc), + } + ) + param_names.append((param.id, param.name, param_uuid)) + progress.advance(task) + + console.print(f" Inserting {len(param_rows)} parameters...") + bulk_insert( + session, + "parameters", + [ + "id", + "name", + "label", + "description", + "data_type", + 
"unit", + "tax_benefit_model_version_id", + "created_at", + ], + param_rows, + ) + + # Build param_id_map from pre-generated UUIDs + param_id_map = {pe_id: db_uuid for pe_id, name, db_uuid in param_names} + + console.print( + f" [green]✓[/green] Added {len(parameters_to_add)} parameters" + ) + + # Add parameter values + # Filter to only include values for parameters we added + parameter_values_to_add = [ + pv + for pv in model_version.parameter_values + if pv.parameter.id in param_id_map + ] + console.print(f" Found {len(parameter_values_to_add)} parameter values to add") + + with logfire.span("add_parameter_values", count=len(parameter_values_to_add)): + pv_rows = [] + skipped = 0 + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task( + f"Preparing {len(parameter_values_to_add)} parameter values", + total=len(parameter_values_to_add), + ) + for pv in parameter_values_to_add: + # Handle Infinity values - skip them as they can't be stored in JSON + if isinstance(pv.value, float) and ( + math.isinf(pv.value) or math.isnan(pv.value) + ): + skipped += 1 + progress.advance(task) + continue + + # Source data has dates swapped (start > end), fix ordering + # Only swap if both dates are set, otherwise keep original + if pv.start_date and pv.end_date: + start = pv.end_date # Swap: source end is our start + end = pv.start_date # Swap: source start is our end + else: + start = pv.start_date + end = pv.end_date + pv_rows.append( + { + "id": uuid4(), + "parameter_id": param_id_map[pv.parameter.id], + "value_json": json.dumps(pv.value), + "start_date": start, + "end_date": end, + "policy_id": None, + "dynamic_id": None, + "created_at": datetime.now(timezone.utc), + } + ) + progress.advance(task) + + console.print(f" Inserting {len(pv_rows)} parameter values...") + bulk_insert( + session, + "parameter_values", + [ + "id", + "parameter_id", + "value_json", + "start_date", + 
"end_date", + "policy_id", + "dynamic_id", + "created_at", + ], + pv_rows, + ) + + console.print( + f" [green]✓[/green] Added {len(pv_rows)} parameter values" + + (f" (skipped {skipped} invalid)" if skipped else "") + ) + + return db_version diff --git a/scripts/seed_nevada.py b/scripts/seed_nevada.py new file mode 100644 index 0000000..0af2cb4 --- /dev/null +++ b/scripts/seed_nevada.py @@ -0,0 +1,128 @@ +"""Seed Nevada datasets into local Supabase. + +This script seeds pre-created Nevada state and congressional district datasets +into the local Supabase database for testing purposes. + +Usage: + uv run python scripts/seed_nevada.py +""" + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent / "src")) + +from rich.console import Console +from sqlmodel import Session, create_engine, select + +from policyengine_api.config.settings import settings +from policyengine_api.models import Dataset, TaxBenefitModel +from policyengine_api.services.storage import upload_dataset_for_seeding + +console = Console() + +# Nevada datasets location +NEVADA_DATA_DIR = Path(__file__).parent.parent / "test_data" / "nevada_datasets" + + +def main(): + """Seed Nevada datasets.""" + console.print("[bold blue]Seeding Nevada datasets for testing...") + + engine = create_engine(settings.database_url, echo=False) + + with Session(engine) as session: + # Get or create US model + us_model = session.exec( + select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-us") + ).first() + + if not us_model: + console.print(" Creating US tax-benefit model...") + us_model = TaxBenefitModel( + name="policyengine-us", + description="US tax-benefit system model", + ) + session.add(us_model) + session.commit() + session.refresh(us_model) + console.print(" [green]✓[/green] Created policyengine-us model") + + # Seed state datasets + states_dir = NEVADA_DATA_DIR / "states" + if states_dir.exists(): + console.print("\n [bold]Nevada State Datasets:[/bold]") + for 
h5_file in sorted(states_dir.glob("*.h5")): + name = h5_file.stem # e.g., "NV_year_2024" + year = int(name.split("_")[-1]) + + # Check if already exists + existing = session.exec( + select(Dataset).where(Dataset.name == name) + ).first() + + if existing: + console.print(f" [yellow]⏭[/yellow] {name} (already exists)") + continue + + # Upload to storage + console.print(f" Uploading {name}...", end=" ") + try: + object_name = upload_dataset_for_seeding(str(h5_file)) + + # Create database record + db_dataset = Dataset( + name=name, + description=f"Nevada state dataset for year {year}", + filepath=object_name, + year=year, + tax_benefit_model_id=us_model.id, + ) + session.add(db_dataset) + session.commit() + console.print("[green]✓[/green]") + except Exception as e: + console.print(f"[red]✗ {e}[/red]") + + # Seed district datasets + districts_dir = NEVADA_DATA_DIR / "districts" + if districts_dir.exists(): + console.print("\n [bold]Nevada Congressional District Datasets:[/bold]") + for h5_file in sorted(districts_dir.glob("*.h5")): + name = h5_file.stem # e.g., "NV-01_year_2024" + year = int(name.split("_")[-1]) + district = name.split("_")[0] # e.g., "NV-01" + + # Check if already exists + existing = session.exec( + select(Dataset).where(Dataset.name == name) + ).first() + + if existing: + console.print(f" [yellow]⏭[/yellow] {name} (already exists)") + continue + + # Upload to storage + console.print(f" Uploading {name}...", end=" ") + try: + object_name = upload_dataset_for_seeding(str(h5_file)) + + # Create database record + db_dataset = Dataset( + name=name, + description=f"{district} congressional district dataset for year {year}", + filepath=object_name, + year=year, + tax_benefit_model_id=us_model.id, + ) + session.add(db_dataset) + session.commit() + console.print("[green]✓[/green]") + except Exception as e: + console.print(f"[red]✗ {e}[/red]") + + console.print("\n[bold green]✓ Nevada datasets seeded successfully![/bold green]") + + +if __name__ == "__main__": 
+ main() diff --git a/scripts/seed_policies.py b/scripts/seed_policies.py new file mode 100644 index 0000000..e57b964 --- /dev/null +++ b/scripts/seed_policies.py @@ -0,0 +1,143 @@ +"""Seed example policy reforms for UK and US.""" + +import time +from datetime import datetime, timezone + +import logfire +from sqlmodel import select + +from seed_common import console, get_session + + +def main(): + from policyengine_api.models import ( + Parameter, + ParameterValue, + Policy, + TaxBenefitModel, + TaxBenefitModelVersion, + ) + + console.print("[bold green]Seeding example policies...[/bold green]\n") + + start = time.time() + with get_session() as session: + with logfire.span("seed_example_policies"): + # Get model versions + uk_model = session.exec( + select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-uk") + ).first() + us_model = session.exec( + select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-us") + ).first() + + if not uk_model or not us_model: + console.print( + "[red]Error: UK or US model not found. 
Run seed_*_model.py first.[/red]" + ) + return + + uk_version = session.exec( + select(TaxBenefitModelVersion) + .where(TaxBenefitModelVersion.model_id == uk_model.id) + .order_by(TaxBenefitModelVersion.created_at.desc()) + ).first() + + us_version = session.exec( + select(TaxBenefitModelVersion) + .where(TaxBenefitModelVersion.model_id == us_model.id) + .order_by(TaxBenefitModelVersion.created_at.desc()) + ).first() + + # UK example policy: raise basic rate to 22p + uk_policy_name = "UK basic rate 22p" + existing_uk_policy = session.exec( + select(Policy).where(Policy.name == uk_policy_name) + ).first() + + if existing_uk_policy: + console.print(f" Policy '{uk_policy_name}' already exists, skipping") + else: + # Find the basic rate parameter + uk_basic_rate_param = session.exec( + select(Parameter).where( + Parameter.name == "gov.hmrc.income_tax.rates.uk[0].rate", + Parameter.tax_benefit_model_version_id == uk_version.id, + ) + ).first() + + if uk_basic_rate_param: + uk_policy = Policy( + name=uk_policy_name, + description="Raise the UK income tax basic rate from 20p to 22p", + ) + session.add(uk_policy) + session.commit() + session.refresh(uk_policy) + + # Add parameter value (22% = 0.22) + uk_param_value = ParameterValue( + parameter_id=uk_basic_rate_param.id, + value_json={"value": 0.22}, + start_date=datetime(2024, 1, 1, tzinfo=timezone.utc), + end_date=None, + policy_id=uk_policy.id, + ) + session.add(uk_param_value) + session.commit() + console.print(f" [green]✓[/green] Created UK policy: {uk_policy_name}") + else: + console.print( + " [yellow]Warning: UK basic rate parameter not found[/yellow]" + ) + + # US example policy: raise first bracket rate to 12% + us_policy_name = "US 12% lowest bracket" + existing_us_policy = session.exec( + select(Policy).where(Policy.name == us_policy_name) + ).first() + + if existing_us_policy: + console.print(f" Policy '{us_policy_name}' already exists, skipping") + else: + # Find the first bracket rate parameter + 
us_first_bracket_param = session.exec( + select(Parameter).where( + Parameter.name == "gov.irs.income.bracket.rates.1", + Parameter.tax_benefit_model_version_id == us_version.id, + ) + ).first() + + if us_first_bracket_param: + us_policy = Policy( + name=us_policy_name, + description="Raise US federal income tax lowest bracket to 12%", + ) + session.add(us_policy) + session.commit() + session.refresh(us_policy) + + # Add parameter value (12% = 0.12) + us_param_value = ParameterValue( + parameter_id=us_first_bracket_param.id, + value_json={"value": 0.12}, + start_date=datetime(2024, 1, 1, tzinfo=timezone.utc), + end_date=None, + policy_id=us_policy.id, + ) + session.add(us_param_value) + session.commit() + console.print(f" [green]✓[/green] Created US policy: {us_policy_name}") + else: + console.print( + " [yellow]Warning: US first bracket parameter not found[/yellow]" + ) + + console.print("[green]✓[/green] Example policies seeded") + + elapsed = time.time() - start + console.print(f"\n[bold]Total time: {elapsed:.1f}s[/bold]") + + +if __name__ == "__main__": + main() diff --git a/scripts/seed_uk_datasets.py b/scripts/seed_uk_datasets.py new file mode 100644 index 0000000..1754454 --- /dev/null +++ b/scripts/seed_uk_datasets.py @@ -0,0 +1,113 @@ +"""Seed UK datasets (FRS) and upload to S3. + +NOTE: Requires HUGGING_FACE_TOKEN environment variable to be set, +as UK FRS datasets are hosted on a private HuggingFace repository. 
+""" + +import argparse +import time +from pathlib import Path + +import logfire +from rich.progress import Progress, SpinnerColumn, TextColumn +from sqlmodel import select + +from seed_common import console, get_session + + +def main(): + parser = argparse.ArgumentParser(description="Seed UK datasets") + parser.add_argument( + "--lite", + action="store_true", + help="Lite mode: only seed FRS 2026", + ) + args = parser.parse_args() + + # Import here to avoid slow import at module level + from policyengine.tax_benefit_models.uk.datasets import ( + ensure_datasets as ensure_uk_datasets, + ) + + from policyengine_api.models import Dataset, TaxBenefitModel + from policyengine_api.services.storage import upload_dataset_for_seeding + + console.print("[bold green]Seeding UK datasets...[/bold green]\n") + console.print("[yellow]Note: Requires HUGGING_FACE_TOKEN environment variable[/yellow]\n") + + start = time.time() + with get_session() as session: + # Get UK model + uk_model = session.exec( + select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-uk") + ).first() + + if not uk_model: + console.print("[red]Error: UK model not found. 
Run seed_uk_model.py first.[/red]") + return + + data_folder = str(Path(__file__).parent.parent / "data") + console.print(f" Data folder: {data_folder}") + + # Get datasets + console.print(" Loading UK datasets from policyengine package...") + ds_start = time.time() + uk_datasets = ensure_uk_datasets(data_folder=data_folder) + console.print(f" Loaded {len(uk_datasets)} datasets in {time.time() - ds_start:.1f}s") + + # In lite mode, only upload FRS 2026 + if args.lite: + uk_datasets = { + k: v for k, v in uk_datasets.items() if v.year == 2026 and "frs" in k + } + console.print(f" Lite mode: filtered to {len(uk_datasets)} dataset(s)") + + created = 0 + skipped = 0 + + with logfire.span("seed_uk_datasets", count=len(uk_datasets)): + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task("UK datasets", total=len(uk_datasets)) + for name, pe_dataset in uk_datasets.items(): + progress.update(task, description=f"UK: {pe_dataset.name}") + + # Check if dataset already exists + existing = session.exec( + select(Dataset).where(Dataset.name == pe_dataset.name) + ).first() + + if existing: + skipped += 1 + progress.advance(task) + continue + + # Upload to S3 + upload_start = time.time() + object_name = upload_dataset_for_seeding(pe_dataset.filepath) + console.print(f" Uploaded {pe_dataset.name} in {time.time() - upload_start:.1f}s") + + # Create database record + db_dataset = Dataset( + name=pe_dataset.name, + description=pe_dataset.description, + filepath=object_name, + year=pe_dataset.year, + tax_benefit_model_id=uk_model.id, + ) + session.add(db_dataset) + session.commit() + created += 1 + progress.advance(task) + + console.print(f"[green]✓[/green] UK datasets: {created} created, {skipped} skipped") + + elapsed = time.time() - start + console.print(f"\n[bold]Total time: {elapsed:.1f}s[/bold]") + + +if __name__ == "__main__": + main() diff --git a/scripts/seed_uk_model.py 
b/scripts/seed_uk_model.py new file mode 100644 index 0000000..07543bf --- /dev/null +++ b/scripts/seed_uk_model.py @@ -0,0 +1,33 @@ +"""Seed UK model (variables, parameters, parameter values).""" + +import argparse +import time + +from seed_common import console, get_session, seed_model + + +def main(): + parser = argparse.ArgumentParser(description="Seed UK model") + parser.add_argument( + "--lite", + action="store_true", + help="Lite mode: skip state parameters", + ) + args = parser.parse_args() + + # Import here to avoid slow import at module level + from policyengine.tax_benefit_models.uk import uk_latest + + console.print("[bold green]Seeding UK model...[/bold green]\n") + + start = time.time() + with get_session() as session: + version = seed_model(uk_latest, session, lite=args.lite) + console.print(f"[green]✓[/green] UK model seeded: {version.id}") + + elapsed = time.time() - start + console.print(f"\n[bold]Total time: {elapsed:.1f}s[/bold]") + + +if __name__ == "__main__": + main() diff --git a/scripts/seed_us_datasets.py b/scripts/seed_us_datasets.py new file mode 100644 index 0000000..abf1995 --- /dev/null +++ b/scripts/seed_us_datasets.py @@ -0,0 +1,108 @@ +"""Seed US datasets (CPS) and upload to S3.""" + +import argparse +import time +from pathlib import Path + +import logfire +from rich.progress import Progress, SpinnerColumn, TextColumn +from sqlmodel import select + +from seed_common import console, get_session + + +def main(): + parser = argparse.ArgumentParser(description="Seed US datasets") + parser.add_argument( + "--lite", + action="store_true", + help="Lite mode: only seed CPS 2026", + ) + args = parser.parse_args() + + # Import here to avoid slow import at module level + from policyengine.tax_benefit_models.us.datasets import ( + ensure_datasets as ensure_us_datasets, + ) + + from policyengine_api.models import Dataset, TaxBenefitModel + from policyengine_api.services.storage import upload_dataset_for_seeding + + console.print("[bold 
green]Seeding US datasets...[/bold green]\n") + + start = time.time() + with get_session() as session: + # Get US model + us_model = session.exec( + select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-us") + ).first() + + if not us_model: + console.print("[red]Error: US model not found. Run seed_us_model.py first.[/red]") + return + + data_folder = str(Path(__file__).parent.parent / "data") + console.print(f" Data folder: {data_folder}") + + # Get datasets + console.print(" Loading US datasets from policyengine package...") + ds_start = time.time() + us_datasets = ensure_us_datasets(data_folder=data_folder) + console.print(f" Loaded {len(us_datasets)} datasets in {time.time() - ds_start:.1f}s") + + # In lite mode, only upload CPS 2026 + if args.lite: + us_datasets = { + k: v for k, v in us_datasets.items() if v.year == 2026 and "cps" in k + } + console.print(f" Lite mode: filtered to {len(us_datasets)} dataset(s)") + + created = 0 + skipped = 0 + + with logfire.span("seed_us_datasets", count=len(us_datasets)): + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task("US datasets", total=len(us_datasets)) + for name, pe_dataset in us_datasets.items(): + progress.update(task, description=f"US: {pe_dataset.name}") + + # Check if dataset already exists + existing = session.exec( + select(Dataset).where(Dataset.name == pe_dataset.name) + ).first() + + if existing: + skipped += 1 + progress.advance(task) + continue + + # Upload to S3 + upload_start = time.time() + object_name = upload_dataset_for_seeding(pe_dataset.filepath) + console.print(f" Uploaded {pe_dataset.name} in {time.time() - upload_start:.1f}s") + + # Create database record + db_dataset = Dataset( + name=pe_dataset.name, + description=pe_dataset.description, + filepath=object_name, + year=pe_dataset.year, + tax_benefit_model_id=us_model.id, + ) + session.add(db_dataset) + session.commit() + 
created += 1 + progress.advance(task) + + console.print(f"[green]✓[/green] US datasets: {created} created, {skipped} skipped") + + elapsed = time.time() - start + console.print(f"\n[bold]Total time: {elapsed:.1f}s[/bold]") + + +if __name__ == "__main__": + main() diff --git a/scripts/seed_us_model.py b/scripts/seed_us_model.py new file mode 100644 index 0000000..ce8a829 --- /dev/null +++ b/scripts/seed_us_model.py @@ -0,0 +1,33 @@ +"""Seed US model (variables, parameters, parameter values).""" + +import argparse +import time + +from seed_common import console, get_session, seed_model + + +def main(): + parser = argparse.ArgumentParser(description="Seed US model") + parser.add_argument( + "--lite", + action="store_true", + help="Lite mode: skip state parameters", + ) + args = parser.parse_args() + + # Import here to avoid slow import at module level + from policyengine.tax_benefit_models.us import us_latest + + console.print("[bold green]Seeding US model...[/bold green]\n") + + start = time.time() + with get_session() as session: + version = seed_model(us_latest, session, lite=args.lite) + console.print(f"[green]✓[/green] US model seeded: {version.id}") + + elapsed = time.time() - start + console.print(f"\n[bold]Total time: {elapsed:.1f}s[/bold]") + + +if __name__ == "__main__": + main() diff --git a/scripts/test_economy_simulation.py b/scripts/test_economy_simulation.py new file mode 100644 index 0000000..3845fc4 --- /dev/null +++ b/scripts/test_economy_simulation.py @@ -0,0 +1,277 @@ +"""Test economy-wide simulation following the exact flow from modal_app.py. + +This script mimics the economy-wide simulation code path as closely as possible +to verify whether policy reforms are being applied correctly. 
+""" + +import tempfile +from datetime import datetime +from pathlib import Path + +import pandas as pd +from microdf import MicroDataFrame + +# Import exactly as modal_app.py does +from policyengine.core import Simulation as PESimulation +from policyengine.core.policy import ParameterValue as PEParameterValue +from policyengine.core.policy import Policy as PEPolicy +from policyengine.tax_benefit_models.us import us_latest +from policyengine.tax_benefit_models.us.datasets import PolicyEngineUSDataset, USYearData + + +def create_test_dataset(year: int) -> PolicyEngineUSDataset: + """Create a small test dataset similar to what would be loaded from storage. + + Uses the same structure as economy-wide datasets but with just a few households. + """ + # Create 3 test households with different income levels + # Each household has 2 adults + 2 children (to test CTC) + n_households = 3 + n_people = n_households * 4 # 4 people per household + + # Person data + person_data = { + "person_id": list(range(n_people)), + "person_household_id": [i // 4 for i in range(n_people)], + "person_marital_unit_id": [], + "person_family_id": [i // 4 for i in range(n_people)], + "person_spm_unit_id": [i // 4 for i in range(n_people)], + "person_tax_unit_id": [i // 4 for i in range(n_people)], + "person_weight": [1000.0] * n_people, # Weight for population scaling + "age": [], + "employment_income": [], + } + + # Build person details + marital_unit_counter = 0 + for hh in range(n_households): + base_income = 10000 + (hh * 20000) # 10k, 30k, 50k + # Adult 1 + person_data["age"].append(35) + person_data["employment_income"].append(base_income) + person_data["person_marital_unit_id"].append(marital_unit_counter) + # Adult 2 + person_data["age"].append(33) + person_data["employment_income"].append(0) + person_data["person_marital_unit_id"].append(marital_unit_counter) + marital_unit_counter += 1 + # Child 1 + person_data["age"].append(5) + person_data["employment_income"].append(0) + 
person_data["person_marital_unit_id"].append(marital_unit_counter) + marital_unit_counter += 1 + # Child 2 + person_data["age"].append(3) + person_data["employment_income"].append(0) + person_data["person_marital_unit_id"].append(marital_unit_counter) + marital_unit_counter += 1 + + n_marital_units = marital_unit_counter + + # Entity data + household_data = { + "household_id": list(range(n_households)), + "household_weight": [1000.0] * n_households, + "state_fips": [48] * n_households, # Texas + } + + marital_unit_data = { + "marital_unit_id": list(range(n_marital_units)), + "marital_unit_weight": [1000.0] * n_marital_units, + } + + family_data = { + "family_id": list(range(n_households)), + "family_weight": [1000.0] * n_households, + } + + spm_unit_data = { + "spm_unit_id": list(range(n_households)), + "spm_unit_weight": [1000.0] * n_households, + } + + tax_unit_data = { + "tax_unit_id": list(range(n_households)), + "tax_unit_weight": [1000.0] * n_households, + } + + # Create MicroDataFrames (same as economy datasets) + person_df = MicroDataFrame(pd.DataFrame(person_data), weights="person_weight") + household_df = MicroDataFrame(pd.DataFrame(household_data), weights="household_weight") + marital_unit_df = MicroDataFrame(pd.DataFrame(marital_unit_data), weights="marital_unit_weight") + family_df = MicroDataFrame(pd.DataFrame(family_data), weights="family_weight") + spm_unit_df = MicroDataFrame(pd.DataFrame(spm_unit_data), weights="spm_unit_weight") + tax_unit_df = MicroDataFrame(pd.DataFrame(tax_unit_data), weights="tax_unit_weight") + + # Create dataset file + tmpdir = tempfile.mkdtemp() + filepath = str(Path(tmpdir) / "test_economy.h5") + + return PolicyEngineUSDataset( + name="Test Economy Dataset", + description="Small test dataset for economy simulation", + filepath=filepath, + year=year, + data=USYearData( + person=person_df, + household=household_df, + marital_unit=marital_unit_df, + family=family_df, + spm_unit=spm_unit_df, + tax_unit=tax_unit_df, + ), + ) 
+ + +def create_policy_like_modal_app(model_version) -> PEPolicy: + """Create a policy exactly like _get_pe_policy_us does in modal_app.py. + + This mimics the exact flow: + 1. Look up parameter by name from model_version.parameters + 2. Create PEParameterValue with the parameter, value, start_date, end_date + 3. Create PEPolicy with the parameter values + """ + param_lookup = {p.name: p for p in model_version.parameters} + + # This is exactly what _get_pe_policy_us does + pe_param = param_lookup.get("gov.irs.credits.ctc.refundable.fully_refundable") + if not pe_param: + raise ValueError("Parameter not found!") + + pe_pv = PEParameterValue( + parameter=pe_param, + value=True, # Make CTC fully refundable + start_date=datetime(2024, 1, 1), + end_date=None, + ) + + return PEPolicy( + name="CTC Fully Refundable", + description="Makes CTC fully refundable", + parameter_values=[pe_pv], + ) + + +def run_economy_simulation(dataset: PolicyEngineUSDataset, policy: PEPolicy | None, label: str) -> dict: + """Run an economy simulation exactly like modal_app.py does. + + This follows the exact flow from simulate_economy_us: + 1. Create PESimulation with dataset, model version, policy, dynamic + 2. Call pe_sim.ensure() (which calls run() internally) + 3. 
Access output via pe_sim.output_dataset + """ + print(f"\n=== {label} ===") + print(f" Policy: {policy.name if policy else 'None (baseline)'}") + if policy: + print(f" Policy parameter_values: {len(policy.parameter_values)}") + for pv in policy.parameter_values: + print(f" - {pv.parameter.name}: {pv.value} (start: {pv.start_date})") + + pe_model_version = us_latest + + # Create and run simulation - EXACTLY like modal_app.py lines 1006-1012 + pe_sim = PESimulation( + dataset=dataset, + tax_benefit_model_version=pe_model_version, + policy=policy, + dynamic=None, + ) + pe_sim.ensure() + + # Extract results from output dataset + output_data = pe_sim.output_dataset.data + + # Sum up key metrics across all tax units (weighted) + tax_unit_df = pd.DataFrame(output_data.tax_unit) + + # Get the variables we care about + total_ctc = 0 + total_income_tax = 0 + total_eitc = 0 + + for var in ["ctc", "income_tax", "eitc"]: + if var in tax_unit_df.columns: + # Weighted sum + weights = tax_unit_df.get("tax_unit_weight", pd.Series([1.0] * len(tax_unit_df))) + if var == "ctc": + total_ctc = (tax_unit_df[var] * weights).sum() + elif var == "income_tax": + total_income_tax = (tax_unit_df[var] * weights).sum() + elif var == "eitc": + total_eitc = (tax_unit_df[var] * weights).sum() + + print(f" Results (weighted totals across {len(tax_unit_df)} tax units):") + print(f" Total CTC: ${total_ctc:,.0f}") + print(f" Total Income Tax: ${total_income_tax:,.0f}") + print(f" Total EITC: ${total_eitc:,.0f}") + + # Also show per-household breakdown + print(f" Per tax unit breakdown:") + for i in range(len(tax_unit_df)): + ctc = tax_unit_df["ctc"].iloc[i] if "ctc" in tax_unit_df.columns else 0 + income_tax = tax_unit_df["income_tax"].iloc[i] if "income_tax" in tax_unit_df.columns else 0 + print(f" Tax Unit {i}: CTC=${ctc:,.0f}, Income Tax=${income_tax:,.0f}") + + return { + "total_ctc": total_ctc, + "total_income_tax": total_income_tax, + "total_eitc": total_eitc, + "tax_unit_df": tax_unit_df, + } + 
+ +def main(): + print("=" * 60) + print("ECONOMY-WIDE SIMULATION TEST") + print("Following the exact code path from modal_app.py") + print("=" * 60) + + year = 2024 + + # Create test dataset (same for both simulations) + print("\nCreating test dataset...") + + # Run baseline simulation + baseline_dataset = create_test_dataset(year) + baseline_results = run_economy_simulation(baseline_dataset, None, "BASELINE (no policy)") + + # Create policy exactly like modal_app.py does + policy = create_policy_like_modal_app(us_latest) + + # Run reform simulation + reform_dataset = create_test_dataset(year) + reform_results = run_economy_simulation(reform_dataset, policy, "REFORM (CTC fully refundable)") + + # Compare results + print("\n" + "=" * 60) + print("COMPARISON") + print("=" * 60) + + ctc_diff = reform_results["total_ctc"] - baseline_results["total_ctc"] + tax_diff = reform_results["total_income_tax"] - baseline_results["total_income_tax"] + + print(f"\nTotal CTC:") + print(f" Baseline: ${baseline_results['total_ctc']:,.0f}") + print(f" Reform: ${reform_results['total_ctc']:,.0f}") + print(f" Change: ${ctc_diff:,.0f}") + + print(f"\nTotal Income Tax:") + print(f" Baseline: ${baseline_results['total_income_tax']:,.0f}") + print(f" Reform: ${reform_results['total_income_tax']:,.0f}") + print(f" Change: ${tax_diff:,.0f}") + + # Verdict + print("\n" + "=" * 60) + print("VERDICT") + print("=" * 60) + + if baseline_results["total_income_tax"] == reform_results["total_income_tax"]: + print("\n❌ BUG CONFIRMED: Results are IDENTICAL!") + print(" The policy reform is NOT being applied to economy simulations.") + else: + print("\n✓ NO BUG: Results differ as expected!") + print(f" The fully refundable CTC reform changed income tax by ${tax_diff:,.0f}") + + +if __name__ == "__main__": + main() diff --git a/scripts/test_household_impact.py b/scripts/test_household_impact.py new file mode 100644 index 0000000..81c85b0 --- /dev/null +++ b/scripts/test_household_impact.py @@ -0,0 
+1,135 @@ +"""Test household impact analysis end-to-end. + +This script tests the async household impact analysis workflow: +1. Create a stored household +2. Run household impact analysis (returns immediately with report_id) +3. Poll until completed +4. Verify results + +Usage: + uv run python scripts/test_household_impact.py +""" + +import sys +import time + +import requests + +BASE_URL = "http://127.0.0.1:8000" + + +def main(): + print("=" * 60) + print("Testing Household Impact Analysis (Async)") + print("=" * 60) + + # Step 1: Create a US household + print("\n1. Creating US household...") + household_data = { + "tax_benefit_model_name": "policyengine_us", + "year": 2024, + "label": "Test household for impact analysis", + "people": [ + { + "age": 35, + "employment_income": 50000, + } + ], + "tax_unit": {}, + "family": {}, + "spm_unit": {}, + "marital_unit": {}, + "household": {"state_code": "NV"}, + } + + resp = requests.post(f"{BASE_URL}/households/", json=household_data) + if resp.status_code != 201: + print(f" FAILED: {resp.status_code} - {resp.text}") + sys.exit(1) + + household = resp.json() + household_id = household["id"] + print(f" Created household: {household_id}") + + # Step 2: Run household impact analysis + print("\n2. Starting household impact analysis...") + impact_request = { + "household_id": household_id, + "policy_id": None, # Single run under current law + } + + resp = requests.post(f"{BASE_URL}/analysis/household-impact", json=impact_request) + if resp.status_code != 200: + print(f" FAILED: {resp.status_code} - {resp.text}") + sys.exit(1) + + result = resp.json() + report_id = result["report_id"] + status = result["status"] + print(f" Report ID: {report_id}") + print(f" Initial status: {status}") + + # Step 3: Poll until completed + print("\n3. 
Polling for results...") + max_attempts = 30 + for attempt in range(max_attempts): + resp = requests.get(f"{BASE_URL}/analysis/household-impact/{report_id}") + if resp.status_code != 200: + print(f" FAILED: {resp.status_code} - {resp.text}") + sys.exit(1) + + result = resp.json() + status = result["status"].upper() # Normalize to uppercase + print(f" Attempt {attempt + 1}: status={status}") + + if status == "COMPLETED": + break + elif status == "FAILED": + print(f" FAILED: {result.get('error_message', 'Unknown error')}") + sys.exit(1) + + time.sleep(0.5) + else: + print(f" FAILED: Timed out after {max_attempts} attempts") + sys.exit(1) + + # Step 4: Verify results + print("\n4. Verifying results...") + baseline_result = result.get("baseline_result") + if not baseline_result: + print(" FAILED: No baseline result") + sys.exit(1) + + print(f" Baseline result keys: {list(baseline_result.keys())}") + + # Check for expected entity types + expected_entities = ["person", "tax_unit", "spm_unit", "family", "marital_unit", "household"] + for entity in expected_entities: + if entity in baseline_result: + print(f" ✓ {entity}: {len(baseline_result[entity])} entities") + else: + print(f" ✗ {entity}: missing") + + # Look for net_income in person output + if "person" in baseline_result and baseline_result["person"]: + person = baseline_result["person"][0] + if "household_net_income" in person: + print(f" household_net_income: ${person['household_net_income']:,.2f}") + elif "spm_unit_net_income" in person: + print(f" spm_unit_net_income: ${person['spm_unit_net_income']:,.2f}") + + # Step 5: Cleanup - delete household + print("\n5. 
Cleaning up...") + resp = requests.delete(f"{BASE_URL}/households/{household_id}") + if resp.status_code == 204: + print(f" Deleted household: {household_id}") + else: + print(f" Warning: Failed to delete household: {resp.status_code}") + + print("\n" + "=" * 60) + print("SUCCESS: Household impact analysis working correctly!") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/scripts/test_household_scenarios.py b/scripts/test_household_scenarios.py new file mode 100644 index 0000000..fb418a4 --- /dev/null +++ b/scripts/test_household_scenarios.py @@ -0,0 +1,344 @@ +"""Test household calculation scenarios. + +Tests: +1. US California household under current law +2. Scotland household under current law +3. US household: current law vs CTC fully refundable reform +""" + +import sys +import time +import requests + +BASE_URL = "http://127.0.0.1:8000" + + +def poll_for_completion(report_id: str, max_attempts: int = 60) -> dict: + """Poll until report is completed or failed.""" + for attempt in range(max_attempts): + resp = requests.get(f"{BASE_URL}/analysis/household-impact/{report_id}") + if resp.status_code != 200: + raise Exception(f"Failed to get report: {resp.status_code} - {resp.text}") + + result = resp.json() + status = result["status"].upper() + + if status == "COMPLETED": + return result + elif status == "FAILED": + raise Exception(f"Report failed: {result.get('error_message', 'Unknown error')}") + + time.sleep(0.5) + + raise Exception(f"Timed out after {max_attempts} attempts") + + +def print_household_summary(result: dict, label: str): + """Print summary of household calculation result.""" + print(f"\n {label}:") + + baseline = result.get("baseline_result", {}) + reform = result.get("reform_result", {}) + + # Get key metrics from person/household + if "person" in baseline and baseline["person"]: + person = baseline["person"][0] + if "household_net_income" in person: + baseline_income = person["household_net_income"] + print(f" Baseline 
net income: ${baseline_income:,.2f}") + + if reform and "person" in reform and reform["person"]: + reform_income = reform["person"][0].get("household_net_income", 0) + print(f" Reform net income: ${reform_income:,.2f}") + print(f" Difference: ${reform_income - baseline_income:,.2f}") + + # Show some tax/benefit info if available + for key in ["income_tax", "federal_income_tax", "state_income_tax", "ctc", "refundable_ctc"]: + if key in person: + print(f" {key}: ${person[key]:,.2f}") + + +def test_us_california(): + """Test 1: US California household under current law.""" + print("\n" + "=" * 60) + print("TEST 1: US California Household - Current Law") + print("=" * 60) + + # Create California household + household_data = { + "tax_benefit_model_name": "policyengine_us", + "year": 2024, + "label": "California test household", + "people": [ + {"age": 35, "employment_income": 75000}, + {"age": 33, "employment_income": 45000}, + {"age": 8}, # Child + ], + "tax_unit": {}, + "family": {}, + "spm_unit": {}, + "marital_unit": {}, + "household": {"state_code": "CA"}, + } + + print("\n Creating household...") + resp = requests.post(f"{BASE_URL}/households/", json=household_data) + if resp.status_code != 201: + print(f" FAILED: {resp.status_code} - {resp.text}") + return None + + household = resp.json() + household_id = household["id"] + print(f" Household ID: {household_id}") + + # Run analysis under current law (no policy_id) + print(" Running analysis...") + resp = requests.post(f"{BASE_URL}/analysis/household-impact", json={ + "household_id": household_id, + "policy_id": None, + }) + + if resp.status_code != 200: + print(f" FAILED: {resp.status_code} - {resp.text}") + return household_id + + report_id = resp.json()["report_id"] + print(f" Report ID: {report_id}") + + # Poll for results + try: + result = poll_for_completion(report_id) + print(" Status: COMPLETED") + print_household_summary(result, "Results") + except Exception as e: + print(f" FAILED: {e}") + + return 
household_id + + +def test_scotland(): + """Test 2: Scotland household under current law.""" + print("\n" + "=" * 60) + print("TEST 2: Scotland Household - Current Law") + print("=" * 60) + + # Create Scotland household + household_data = { + "tax_benefit_model_name": "policyengine_uk", + "year": 2024, + "label": "Scotland test household", + "people": [ + {"age": 40, "employment_income": 45000}, + ], + "benunit": {}, + "household": {"region": "SCOTLAND"}, + } + + print("\n Creating household...") + resp = requests.post(f"{BASE_URL}/households/", json=household_data) + if resp.status_code != 201: + print(f" FAILED: {resp.status_code} - {resp.text}") + return None + + household = resp.json() + household_id = household["id"] + print(f" Household ID: {household_id}") + + # Run analysis under current law + print(" Running analysis...") + resp = requests.post(f"{BASE_URL}/analysis/household-impact", json={ + "household_id": household_id, + "policy_id": None, + }) + + if resp.status_code != 200: + print(f" FAILED: {resp.status_code} - {resp.text}") + return household_id + + report_id = resp.json()["report_id"] + print(f" Report ID: {report_id}") + + # Poll for results + try: + result = poll_for_completion(report_id) + print(" Status: COMPLETED") + print_household_summary(result, "Results") + except Exception as e: + print(f" FAILED: {e}") + + return household_id + + +def test_us_ctc_reform(): + """Test 3: US household - current law vs CTC fully refundable.""" + print("\n" + "=" * 60) + print("TEST 3: US Household - Current Law vs CTC Fully Refundable") + print("=" * 60) + + # First, find the CTC refundability parameter + print("\n Finding CTC refundability parameter...") + resp = requests.get(f"{BASE_URL}/parameters", params={"search": "ctc", "limit": 50}) + if resp.status_code != 200: + print(f" FAILED to search parameters: {resp.status_code}") + return None, None + + params = resp.json() + ctc_param = None + for p in params: + # Look for the refundable portion parameter 
+ if "refundable" in p["name"].lower() and "ctc" in p["name"].lower(): + print(f" Found: {p['name']} (label: {p.get('label')})") + ctc_param = p + break + + if not ctc_param: + # Try searching for child tax credit parameters + print(" Searching for child_tax_credit parameters...") + resp = requests.get(f"{BASE_URL}/parameters", params={"search": "child_tax_credit", "limit": 50}) + params = resp.json() + for p in params: + print(f" - {p['name']}") + if "refundable" in p["name"].lower(): + ctc_param = p + break + + if not ctc_param: + print(" Could not find CTC refundability parameter") + print(" Continuing with household creation anyway...") + + # Create household with children (needed for CTC) + household_data = { + "tax_benefit_model_name": "policyengine_us", + "year": 2024, + "label": "CTC test household", + "people": [ + {"age": 35, "employment_income": 30000}, # Lower income to see CTC effect + {"age": 33, "employment_income": 0}, + {"age": 5}, # Child 1 + {"age": 3}, # Child 2 + ], + "tax_unit": {}, + "family": {}, + "spm_unit": {}, + "marital_unit": {}, + "household": {"state_code": "TX"}, # Texas - no state income tax + } + + print("\n Creating household...") + resp = requests.post(f"{BASE_URL}/households/", json=household_data) + if resp.status_code != 201: + print(f" FAILED: {resp.status_code} - {resp.text}") + return None, None + + household = resp.json() + household_id = household["id"] + print(f" Household ID: {household_id}") + + # Create a policy that makes CTC fully refundable + policy_id = None + if ctc_param: + print("\n Creating CTC fully refundable policy...") + policy_data = { + "name": "CTC Fully Refundable", + "description": "Makes the Child Tax Credit fully refundable", + } + resp = requests.post(f"{BASE_URL}/policies/", json=policy_data) + if resp.status_code == 201: + policy = resp.json() + policy_id = policy["id"] + print(f" Policy ID: {policy_id}") + + # Add parameter value to make CTC fully refundable + # The parameter should set 
refundable portion to 100% or max amount + pv_data = { + "parameter_id": ctc_param["id"], + "value_json": 1.0, # 100% refundable + "start_date": "2024-01-01T00:00:00Z", + "end_date": None, + "policy_id": policy_id, + } + resp = requests.post(f"{BASE_URL}/parameter-values/", json=pv_data) + if resp.status_code == 201: + print(" Added parameter value for full refundability") + else: + print(f" Warning: Failed to add parameter value: {resp.status_code} - {resp.text}") + else: + print(f" Warning: Failed to create policy: {resp.status_code}") + + # Run analysis with reform policy + print("\n Running analysis (baseline vs reform)...") + resp = requests.post(f"{BASE_URL}/analysis/household-impact", json={ + "household_id": household_id, + "policy_id": policy_id, + }) + + if resp.status_code != 200: + print(f" FAILED: {resp.status_code} - {resp.text}") + return household_id, policy_id + + report_id = resp.json()["report_id"] + print(f" Report ID: {report_id}") + + # Poll for results + try: + result = poll_for_completion(report_id) + print(" Status: COMPLETED") + print_household_summary(result, "Results") + except Exception as e: + print(f" FAILED: {e}") + + return household_id, policy_id + + +def main(): + print("=" * 60) + print("HOUSEHOLD CALCULATION SCENARIO TESTS") + print("=" * 60) + + # Track created resources for cleanup + households = [] + policies = [] + + # Test 1: US California + hh_id = test_us_california() + if hh_id: + households.append(hh_id) + + # Test 2: Scotland + hh_id = test_scotland() + if hh_id: + households.append(hh_id) + + # Test 3: CTC Reform + hh_id, policy_id = test_us_ctc_reform() + if hh_id: + households.append(hh_id) + if policy_id: + policies.append(policy_id) + + # Cleanup + print("\n" + "=" * 60) + print("CLEANUP") + print("=" * 60) + + for hh_id in households: + resp = requests.delete(f"{BASE_URL}/households/{hh_id}") + if resp.status_code == 204: + print(f" Deleted household: {hh_id}") + else: + print(f" Warning: Failed to delete 
household {hh_id}: {resp.status_code}") + + for policy_id in policies: + resp = requests.delete(f"{BASE_URL}/policies/{policy_id}") + if resp.status_code == 204: + print(f" Deleted policy: {policy_id}") + else: + print(f" Warning: Failed to delete policy {policy_id}: {resp.status_code}") + + print("\n" + "=" * 60) + print("TESTS COMPLETE") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/src/policyengine_api/api/__init__.py b/src/policyengine_api/api/__init__.py index 881af99..c3e0353 100644 --- a/src/policyengine_api/api/__init__.py +++ b/src/policyengine_api/api/__init__.py @@ -9,6 +9,8 @@ datasets, dynamics, household, + household_analysis, + households, outputs, parameter_values, parameters, @@ -16,6 +18,7 @@ simulations, tax_benefit_model_versions, tax_benefit_models, + user_household_associations, variables, ) @@ -33,7 +36,10 @@ api_router.include_router(tax_benefit_model_versions.router) api_router.include_router(change_aggregates.router) api_router.include_router(household.router) +api_router.include_router(household_analysis.router) +api_router.include_router(households.router) api_router.include_router(analysis.router) api_router.include_router(agent.router) +api_router.include_router(user_household_associations.router) __all__ = ["api_router"] diff --git a/src/policyengine_api/api/analysis.py b/src/policyengine_api/api/analysis.py index c9aa86d..10e6fc5 100644 --- a/src/policyengine_api/api/analysis.py +++ b/src/policyengine_api/api/analysis.py @@ -35,6 +35,7 @@ ReportStatus, Simulation, SimulationStatus, + SimulationType, TaxBenefitModel, TaxBenefitModelVersion, ) @@ -138,19 +139,24 @@ def _get_model_version( def _get_deterministic_simulation_id( - dataset_id: UUID, + simulation_type: SimulationType, model_version_id: UUID, policy_id: UUID | None, dynamic_id: UUID | None, + dataset_id: UUID | None = None, + household_id: UUID | None = None, ) -> UUID: """Generate a deterministic UUID from simulation parameters.""" - key = 
f"{dataset_id}:{model_version_id}:{policy_id}:{dynamic_id}" + if simulation_type == SimulationType.ECONOMY: + key = f"economy:{dataset_id}:{model_version_id}:{policy_id}:{dynamic_id}" + else: + key = f"household:{household_id}:{model_version_id}:{policy_id}:{dynamic_id}" return uuid5(SIMULATION_NAMESPACE, key) def _get_deterministic_report_id( baseline_sim_id: UUID, - reform_sim_id: UUID, + reform_sim_id: UUID | None, ) -> UUID: """Generate a deterministic UUID from report parameters.""" key = f"{baseline_sim_id}:{reform_sim_id}" @@ -158,15 +164,22 @@ def _get_deterministic_report_id( def _get_or_create_simulation( - dataset_id: UUID, + simulation_type: SimulationType, model_version_id: UUID, policy_id: UUID | None, dynamic_id: UUID | None, session: Session, + dataset_id: UUID | None = None, + household_id: UUID | None = None, ) -> Simulation: """Get existing simulation or create a new one.""" sim_id = _get_deterministic_simulation_id( - dataset_id, model_version_id, policy_id, dynamic_id + simulation_type, + model_version_id, + policy_id, + dynamic_id, + dataset_id=dataset_id, + household_id=household_id, ) existing = session.get(Simulation, sim_id) @@ -175,7 +188,9 @@ def _get_or_create_simulation( simulation = Simulation( id=sim_id, + simulation_type=simulation_type, dataset_id=dataset_id, + household_id=household_id, tax_benefit_model_version_id=model_version_id, policy_id=policy_id, dynamic_id=dynamic_id, @@ -189,8 +204,9 @@ def _get_or_create_simulation( def _get_or_create_report( baseline_sim_id: UUID, - reform_sim_id: UUID, + reform_sim_id: UUID | None, label: str, + report_type: str, session: Session, ) -> Report: """Get existing report or create a new one.""" @@ -203,6 +219,7 @@ def _get_or_create_report( report = Report( id=report_id, label=label, + report_type=report_type, baseline_simulation_id=baseline_sim_id, reform_simulation_id=reform_sim_id, status=ReportStatus.PENDING, @@ -580,19 +597,21 @@ def economic_impact( # Get or create simulations 
baseline_sim = _get_or_create_simulation( - dataset_id=request.dataset_id, + simulation_type=SimulationType.ECONOMY, model_version_id=model_version.id, policy_id=None, dynamic_id=request.dynamic_id, session=session, + dataset_id=request.dataset_id, ) reform_sim = _get_or_create_simulation( - dataset_id=request.dataset_id, + simulation_type=SimulationType.ECONOMY, model_version_id=model_version.id, policy_id=request.policy_id, dynamic_id=request.dynamic_id, session=session, + dataset_id=request.dataset_id, ) # Get or create report @@ -600,7 +619,9 @@ def economic_impact( if request.policy_id: label += f" (policy {request.policy_id})" - report = _get_or_create_report(baseline_sim.id, reform_sim.id, label, session) + report = _get_or_create_report( + baseline_sim.id, reform_sim.id, label, "economy_comparison", session + ) # Trigger computation if report is pending if report.status == ReportStatus.PENDING: diff --git a/src/policyengine_api/api/household.py b/src/policyengine_api/api/household.py index 0e89b5e..adb6ac9 100644 --- a/src/policyengine_api/api/household.py +++ b/src/policyengine_api/api/household.py @@ -294,17 +294,16 @@ def _calculate_household_uk( Supports multiple households via entity relational dataframes. If entity IDs are not provided, defaults to single household with all people in it. - """ - import tempfile - from datetime import datetime - from pathlib import Path + Uses policyengine-uk Microsimulation directly with reform dict to ensure + policy changes are applied correctly. 
+ """ + import numpy as np import pandas as pd - from policyengine.core import Simulation - from microdf import MicroDataFrame from policyengine.tax_benefit_models.uk import uk_latest - from policyengine.tax_benefit_models.uk.datasets import PolicyEngineUKDataset - from policyengine.tax_benefit_models.uk.datasets import UKYearData + from policyengine_core.simulations.simulation_builder import SimulationBuilder + from policyengine_uk import Microsimulation + from policyengine_uk.system import system n_people = len(people) n_benunits = max(1, len(benunit)) @@ -350,68 +349,88 @@ def _calculate_household_uk( household_data[key] = [0.0] * n_households household_data[key][i] = value - # Create MicroDataFrames - person_df = MicroDataFrame(pd.DataFrame(person_data), weights="person_weight") - benunit_df = MicroDataFrame(pd.DataFrame(benunit_data), weights="benunit_weight") - household_df = MicroDataFrame( - pd.DataFrame(household_data), weights="household_weight" + # Convert policy_data to policyengine-uk reform dict format + # Format: {"param.name": {"YYYY-MM-DD": value}} + reform = None + if policy_data and policy_data.get("parameter_values"): + reform = {} + for pv in policy_data["parameter_values"]: + param_name = pv.get("parameter_name") + value = pv.get("value") + start_date = pv.get("start_date") + + if param_name and start_date: + # Parse ISO date string to get just the date part + if "T" in start_date: + date_str = start_date.split("T")[0] + else: + date_str = start_date + + if param_name not in reform: + reform[param_name] = {} + reform[param_name][date_str] = value + + # Create Microsimulation with reform applied at construction time + microsim = Microsimulation(reform=reform) + + # Build simulation from entity data using SimulationBuilder + person_df = pd.DataFrame(person_data) + + # Determine column naming convention + benunit_id_col = ( + "person_benunit_id" + if "person_benunit_id" in person_df.columns + else "benunit_id" ) - - # Create temporary dataset - 
tmpdir = tempfile.mkdtemp() - filepath = str(Path(tmpdir) / "household_calc.h5") - - dataset = PolicyEngineUKDataset( - name="Household calculation", - description="Household(s) for calculation", - filepath=filepath, - year=year, - data=UKYearData( - person=person_df, - benunit=benunit_df, - household=household_df, - ), + household_id_col = ( + "person_household_id" + if "person_household_id" in person_df.columns + else "household_id" ) - # Build policy if provided - policy = None - if policy_data: - from policyengine.core.policy import ParameterValue as PEParameterValue - from policyengine.core.policy import Policy as PEPolicy - - pe_param_values = [] - param_lookup = {p.name: p for p in uk_latest.parameters} - for pv in policy_data.get("parameter_values", []): - pe_param = param_lookup.get(pv["parameter_name"]) - if pe_param: - pe_pv = PEParameterValue( - parameter=pe_param, - value=pv["value"], - start_date=datetime.fromisoformat(pv["start_date"]) - if pv.get("start_date") - else None, - end_date=datetime.fromisoformat(pv["end_date"]) - if pv.get("end_date") - else None, - ) - pe_param_values.append(pe_pv) - policy = PEPolicy( - name=policy_data.get("name", ""), - description=policy_data.get("description", ""), - parameter_values=pe_param_values, - ) + # Declare entities using SimulationBuilder + builder = SimulationBuilder() + builder.populations = system.instantiate_entities() + + builder.declare_person_entity("person", person_df["person_id"].values) + builder.declare_entity("benunit", np.unique(person_df[benunit_id_col].values)) + builder.declare_entity("household", np.unique(person_df[household_id_col].values)) - # Run simulation - simulation = Simulation( - dataset=dataset, - tax_benefit_model_version=uk_latest, - policy=policy, + # Join persons to group entities + builder.join_with_persons( + builder.populations["benunit"], + person_df[benunit_id_col].values, + np.array(["member"] * len(person_df)), + ) + builder.join_with_persons( + 
builder.populations["household"], + person_df[household_id_col].values, + np.array(["member"] * len(person_df)), ) - simulation.run() - # Extract outputs - output_data = simulation.output_dataset.data + # Build simulation from populations + microsim.build_from_populations(builder.populations) + # Set input variables for each entity + id_columns = { + "person_id", + "benunit_id", + "person_benunit_id", + "household_id", + "person_household_id", + } + + for entity_name, entity_df in [ + ("person", person_data), + ("benunit", benunit_data), + ("household", household_data), + ]: + df = pd.DataFrame(entity_df) + for column in df.columns: + if column not in id_columns and column in system.variables: + microsim.set_input(column, year, df[column].values) + + # Calculate output variables def safe_convert(value): try: return float(value) @@ -422,21 +441,24 @@ def safe_convert(value): for i in range(n_people): person_dict = {} for var in uk_latest.entity_variables["person"]: - person_dict[var] = safe_convert(output_data.person[var].iloc[i]) + val = microsim.calculate(var, period=year, map_to="person") + person_dict[var] = safe_convert(val.values[i]) person_outputs.append(person_dict) benunit_outputs = [] - for i in range(len(output_data.benunit)): + for i in range(n_benunits): benunit_dict = {} for var in uk_latest.entity_variables["benunit"]: - benunit_dict[var] = safe_convert(output_data.benunit[var].iloc[i]) + val = microsim.calculate(var, period=year, map_to="benunit") + benunit_dict[var] = safe_convert(val.values[i]) benunit_outputs.append(benunit_dict) household_outputs = [] - for i in range(len(output_data.household)): + for i in range(n_households): household_dict = {} for var in uk_latest.entity_variables["household"]: - household_dict[var] = safe_convert(output_data.household[var].iloc[i]) + val = microsim.calculate(var, period=year, map_to="household") + household_dict[var] = safe_convert(val.values[i]) household_outputs.append(household_dict) return { @@ -466,7 
+488,14 @@ def _run_local_household_us( try: result = _calculate_household_us( - people, marital_unit, family, spm_unit, tax_unit, household, year, policy_data + people, + marital_unit, + family, + spm_unit, + tax_unit, + household, + year, + policy_data, ) # Update job with result @@ -506,17 +535,16 @@ def _calculate_household_us( Supports multiple households via entity relational dataframes. If entity IDs are not provided, defaults to single household with all people in it. - """ - import tempfile - from datetime import datetime - from pathlib import Path + Uses policyengine-us Microsimulation directly with reform dict to ensure + policy changes are applied correctly. + """ + import numpy as np import pandas as pd - from policyengine.core import Simulation - from microdf import MicroDataFrame from policyengine.tax_benefit_models.us import us_latest - from policyengine.tax_benefit_models.us.datasets import PolicyEngineUSDataset - from policyengine.tax_benefit_models.us.datasets import USYearData + from policyengine_core.simulations.simulation_builder import SimulationBuilder + from policyengine_us import Microsimulation + from policyengine_us.system import system n_people = len(people) n_households = max(1, len(household)) @@ -596,108 +624,158 @@ def _calculate_household_us( tax_unit_data[key] = [0.0] * n_tax_units tax_unit_data[key][i] = value - # Create MicroDataFrames - person_df = MicroDataFrame(pd.DataFrame(person_data), weights="person_weight") - household_df = MicroDataFrame( - pd.DataFrame(household_data), weights="household_weight" + # Convert policy_data to policyengine-us reform dict format + # Format: {"param.name": {"YYYY-MM-DD": value}} + reform = None + if policy_data and policy_data.get("parameter_values"): + reform = {} + for pv in policy_data["parameter_values"]: + param_name = pv.get("parameter_name") + value = pv.get("value") + start_date = pv.get("start_date") + + if param_name and start_date: + # Parse ISO date string to get just the date 
part + if "T" in start_date: + date_str = start_date.split("T")[0] + else: + date_str = start_date + + if param_name not in reform: + reform[param_name] = {} + reform[param_name][date_str] = value + + # Create Microsimulation with reform applied at construction time + # This ensures the reform is properly integrated into the tax benefit system + microsim = Microsimulation(reform=reform) + + # Build simulation from entity data using SimulationBuilder + person_df = pd.DataFrame(person_data) + + # Determine column naming convention + household_id_col = ( + "person_household_id" + if "person_household_id" in person_df.columns + else "household_id" ) - marital_unit_df = MicroDataFrame( - pd.DataFrame(marital_unit_data), weights="marital_unit_weight" + marital_unit_id_col = ( + "person_marital_unit_id" + if "person_marital_unit_id" in person_df.columns + else "marital_unit_id" ) - family_df = MicroDataFrame(pd.DataFrame(family_data), weights="family_weight") - spm_unit_df = MicroDataFrame(pd.DataFrame(spm_unit_data), weights="spm_unit_weight") - tax_unit_df = MicroDataFrame(pd.DataFrame(tax_unit_data), weights="tax_unit_weight") - - # Create temporary dataset - tmpdir = tempfile.mkdtemp() - filepath = str(Path(tmpdir) / "household_calc.h5") - - dataset = PolicyEngineUSDataset( - name="Household calculation", - description="Household(s) for calculation", - filepath=filepath, - year=year, - data=USYearData( - person=person_df, - household=household_df, - marital_unit=marital_unit_df, - family=family_df, - spm_unit=spm_unit_df, - tax_unit=tax_unit_df, - ), + family_id_col = ( + "person_family_id" if "person_family_id" in person_df.columns else "family_id" + ) + spm_unit_id_col = ( + "person_spm_unit_id" + if "person_spm_unit_id" in person_df.columns + else "spm_unit_id" + ) + tax_unit_id_col = ( + "person_tax_unit_id" + if "person_tax_unit_id" in person_df.columns + else "tax_unit_id" ) - # Build policy if provided - policy = None - if policy_data: - from 
policyengine.core.policy import ParameterValue as PEParameterValue - from policyengine.core.policy import Policy as PEPolicy - - pe_param_values = [] - param_lookup = {p.name: p for p in us_latest.parameters} - for pv in policy_data.get("parameter_values", []): - pe_param = param_lookup.get(pv["parameter_name"]) - if pe_param: - pe_pv = PEParameterValue( - parameter=pe_param, - value=pv["value"], - start_date=datetime.fromisoformat(pv["start_date"]) - if pv.get("start_date") - else None, - end_date=datetime.fromisoformat(pv["end_date"]) - if pv.get("end_date") - else None, - ) - pe_param_values.append(pe_pv) - policy = PEPolicy( - name=policy_data.get("name", ""), - description=policy_data.get("description", ""), - parameter_values=pe_param_values, - ) + # Declare entities using SimulationBuilder + builder = SimulationBuilder() + builder.populations = system.instantiate_entities() + + builder.declare_person_entity("person", person_df["person_id"].values) + builder.declare_entity("household", np.unique(person_df[household_id_col].values)) + builder.declare_entity("spm_unit", np.unique(person_df[spm_unit_id_col].values)) + builder.declare_entity("family", np.unique(person_df[family_id_col].values)) + builder.declare_entity("tax_unit", np.unique(person_df[tax_unit_id_col].values)) + builder.declare_entity( + "marital_unit", np.unique(person_df[marital_unit_id_col].values) + ) - # Run simulation - simulation = Simulation( - dataset=dataset, - tax_benefit_model_version=us_latest, - policy=policy, + # Join persons to group entities + builder.join_with_persons( + builder.populations["household"], + person_df[household_id_col].values, + np.array(["member"] * len(person_df)), + ) + builder.join_with_persons( + builder.populations["spm_unit"], + person_df[spm_unit_id_col].values, + np.array(["member"] * len(person_df)), + ) + builder.join_with_persons( + builder.populations["family"], + person_df[family_id_col].values, + np.array(["member"] * len(person_df)), + ) + 
builder.join_with_persons( + builder.populations["tax_unit"], + person_df[tax_unit_id_col].values, + np.array(["member"] * len(person_df)), + ) + builder.join_with_persons( + builder.populations["marital_unit"], + person_df[marital_unit_id_col].values, + np.array(["member"] * len(person_df)), ) - simulation.run() - # Extract outputs - output_data = simulation.output_dataset.data + # Build simulation from populations + microsim.build_from_populations(builder.populations) + + # Set input variables for each entity + id_columns = { + "person_id", + "household_id", + "person_household_id", + "spm_unit_id", + "person_spm_unit_id", + "family_id", + "person_family_id", + "tax_unit_id", + "person_tax_unit_id", + "marital_unit_id", + "person_marital_unit_id", + } + for entity_name, entity_df in [ + ("person", person_data), + ("household", household_data), + ("spm_unit", spm_unit_data), + ("family", family_data), + ("tax_unit", tax_unit_data), + ("marital_unit", marital_unit_data), + ]: + df = pd.DataFrame(entity_df) + for column in df.columns: + if column not in id_columns and column in system.variables: + microsim.set_input(column, year, df[column].values) + + # Calculate output variables def safe_convert(value): try: return float(value) except (ValueError, TypeError): return str(value) - def extract_entity_outputs(entity_name: str, entity_data, n_rows: int) -> list[dict]: + def extract_entity_outputs( + entity_name: str, n_rows: int, map_to: str + ) -> list[dict]: outputs = [] for i in range(n_rows): row_dict = {} for var in us_latest.entity_variables[entity_name]: - row_dict[var] = safe_convert(entity_data[var].iloc[i]) + val = microsim.calculate(var, period=year, map_to=map_to) + row_dict[var] = safe_convert(val.values[i]) outputs.append(row_dict) return outputs return { - "person": extract_entity_outputs("person", output_data.person, n_people), + "person": extract_entity_outputs("person", n_people, "person"), "marital_unit": extract_entity_outputs( - "marital_unit", 
output_data.marital_unit, len(output_data.marital_unit) - ), - "family": extract_entity_outputs( - "family", output_data.family, len(output_data.family) - ), - "spm_unit": extract_entity_outputs( - "spm_unit", output_data.spm_unit, len(output_data.spm_unit) - ), - "tax_unit": extract_entity_outputs( - "tax_unit", output_data.tax_unit, len(output_data.tax_unit) - ), - "household": extract_entity_outputs( - "household", output_data.household, len(output_data.household) + "marital_unit", n_marital_units, "marital_unit" ), + "family": extract_entity_outputs("family", n_families, "family"), + "spm_unit": extract_entity_outputs("spm_unit", n_spm_units, "spm_unit"), + "tax_unit": extract_entity_outputs("tax_unit", n_tax_units, "tax_unit"), + "household": extract_entity_outputs("household", n_households, "household"), } diff --git a/src/policyengine_api/api/household_analysis.py b/src/policyengine_api/api/household_analysis.py new file mode 100644 index 0000000..5f36fda --- /dev/null +++ b/src/policyengine_api/api/household_analysis.py @@ -0,0 +1,726 @@ +"""Household impact analysis endpoints. + +Use these endpoints to analyze household-level effects of policy reforms. +Supports single runs (current law) and comparisons (baseline vs reform). + +WORKFLOW: +1. Create a stored household: POST /households +2. Optionally create a reform policy: POST /policies +3. Run analysis: POST /analysis/household-impact (returns report_id) +4. Poll GET /analysis/household-impact/{report_id} until status="completed" +5. 
Results include baseline_result, reform_result (if comparison), and impact diff +""" + +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import Any, Protocol +from uuid import UUID + +import logfire +from fastapi import APIRouter, Depends, HTTPException +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator +from pydantic import BaseModel, Field +from sqlmodel import Session + +from policyengine_api.models import ( + Household, + Policy, + Report, + ReportStatus, + Simulation, + SimulationStatus, + SimulationType, +) +from policyengine_api.services.database import get_session + +from .analysis import ( + _get_model_version, + _get_or_create_report, + _get_or_create_simulation, +) + + +def get_traceparent() -> str | None: + """Get the current W3C traceparent header for distributed tracing.""" + carrier: dict[str, str] = {} + TraceContextTextMapPropagator().inject(carrier) + return carrier.get("traceparent") + + +router = APIRouter(prefix="/analysis", tags=["analysis"]) + + +# ============================================================================= +# Country Strategy Pattern +# ============================================================================= + + +@dataclass(frozen=True) +class CountryConfig: + """Configuration for a country's household calculation.""" + + name: str + entity_types: tuple[str, ...] 
UK_CONFIG = CountryConfig(
    name="uk",
    entity_types=("person", "benunit", "household"),
)

US_CONFIG = CountryConfig(
    name="us",
    entity_types=(
        "person",
        "tax_unit",
        "spm_unit",
        "family",
        "marital_unit",
        "household",
    ),
)


def get_country_config(tax_benefit_model_name: str) -> CountryConfig:
    """Return the country configuration for a tax-benefit model name.

    Any model other than ``policyengine_uk`` falls through to the US config.
    """
    return UK_CONFIG if tax_benefit_model_name == "policyengine_uk" else US_CONFIG


class HouseholdCalculator(Protocol):
    """Protocol for country-specific household calculators."""

    def __call__(
        self,
        household_data: dict[str, Any],
        year: int,
        policy_data: dict | None,
    ) -> dict: ...


def calculate_uk_household(
    household_data: dict[str, Any],
    year: int,
    policy_data: dict | None,
) -> dict:
    """Run a UK household calculation via the existing implementation."""
    from policyengine_api.api.household import _calculate_household_uk

    return _calculate_household_uk(
        people=household_data.get("people", []),
        benunit=_ensure_list(household_data.get("benunit")),
        household=_ensure_list(household_data.get("household")),
        year=year,
        policy_data=policy_data,
    )


def calculate_us_household(
    household_data: dict[str, Any],
    year: int,
    policy_data: dict | None,
) -> dict:
    """Run a US household calculation via the existing implementation."""
    from policyengine_api.api.household import _calculate_household_us

    return _calculate_household_us(
        people=household_data.get("people", []),
        marital_unit=_ensure_list(household_data.get("marital_unit")),
        family=_ensure_list(household_data.get("family")),
        spm_unit=_ensure_list(household_data.get("spm_unit")),
        tax_unit=_ensure_list(household_data.get("tax_unit")),
        household=_ensure_list(household_data.get("household")),
        year=year,
        policy_data=policy_data,
    )


def get_calculator(tax_benefit_model_name: str) -> HouseholdCalculator:
    """Return the calculator for a tax-benefit model name.

    Mirrors get_country_config: anything that is not ``policyengine_uk``
    is treated as US.
    """
    if tax_benefit_model_name == "policyengine_uk":
        return calculate_uk_household
    return calculate_us_household


# =============================================================================
# Data Transformation Helpers
# =============================================================================


def _ensure_list(value: Any) -> list:
    """Normalize to a list: None -> [], list -> unchanged, anything else -> [value]."""
    if value is None:
        return []
    return value if isinstance(value, list) else [value]


def _extract_policy_data(policy: Policy | None) -> dict | None:
    """Extract policy data from a Policy model into calculation format.

    Returns format expected by _calculate_household_us/_calculate_household_uk:
    {
        "name": "policy name",
        "description": "policy description",
        "parameter_values": [
            {
                "parameter_name": "gov.irs.credits.ctc...",
                "value": 0.16,
                "start_date": "2024-01-01T00:00:00+00:00",
                "end_date": null
            }
        ]
    }

    Returns None when there is no policy or no usable parameter values.
    """
    if not policy or not policy.parameter_values:
        return None

    # Parameter values without a linked parameter record carry no name and
    # cannot be applied, so they are skipped.
    parameter_values = [
        {
            "parameter_name": pv.parameter.name,
            "value": _extract_value(pv.value_json),
            "start_date": _format_date(pv.start_date),
            "end_date": _format_date(pv.end_date),
        }
        for pv in policy.parameter_values
        if pv.parameter
    ]

    if not parameter_values:
        return None

    return {
        "name": policy.name,
        "description": policy.description or "",
        "parameter_values": parameter_values,
    }


def _extract_value(value_json: Any) -> Any:
    """Unwrap {"value": ...} payloads; pass any other shape through unchanged."""
    return value_json.get("value") if isinstance(value_json, dict) else value_json


def _format_date(date: Any) -> str | None:
    """Render a date as an ISO string; None stays None, non-dates are str()'d."""
    if date is None:
        return None
    return date.isoformat() if hasattr(date, "isoformat") else str(date)


# =============================================================================
# Impact Computation
# =============================================================================
def compute_variable_diff(baseline_val: Any, reform_val: Any) -> dict | None:
    """Return {"baseline", "reform", "change"} when both values are numeric, else None."""
    if not (
        isinstance(baseline_val, (int, float)) and isinstance(reform_val, (int, float))
    ):
        return None
    return {
        "baseline": baseline_val,
        "reform": reform_val,
        "change": reform_val - baseline_val,
    }


def compute_entity_diff(baseline_entity: dict, reform_entity: dict) -> dict:
    """Per-variable diffs for one entity instance.

    Variables absent from the reform entity, or non-numeric on either side,
    are omitted.
    """
    return {
        variable: diff
        for variable, base_value in baseline_entity.items()
        if (diff := compute_variable_diff(base_value, reform_entity.get(variable)))
        is not None
    }


def compute_entity_list_diff(
    baseline_list: list[dict],
    reform_list: list[dict],
) -> list[dict]:
    """Pairwise diffs for entity instances (zip truncates to the shorter list)."""
    return [
        compute_entity_diff(base, reform)
        for base, reform in zip(baseline_list, reform_list)
    ]


def compute_household_impact(
    baseline_result: dict,
    reform_result: dict,
    config: CountryConfig,
) -> dict[str, Any]:
    """Diff baseline vs reform across every entity type the country defines.

    Entity types missing from either result are skipped.
    """
    impact: dict[str, Any] = {}
    for entity_type in config.entity_types:
        base_entities = baseline_result.get(entity_type)
        reform_entities = reform_result.get(entity_type)
        if base_entities is None or reform_entities is None:
            continue
        impact[entity_type] = compute_entity_list_diff(base_entities, reform_entities)
    return impact


# =============================================================================
# Simulation Execution
# =============================================================================


def mark_simulation_running(simulation: Simulation, session: Session) -> None:
    """Transition a simulation to RUNNING and stamp its start time."""
    simulation.started_at = datetime.now(timezone.utc)
    simulation.status = SimulationStatus.RUNNING
    session.add(simulation)
    session.commit()


def mark_simulation_completed(
    simulation: Simulation,
    result: dict,
    session: Session,
) -> None:
    """Store the result on the simulation and transition it to COMPLETED."""
    simulation.household_result = result
    simulation.completed_at = datetime.now(timezone.utc)
    simulation.status = SimulationStatus.COMPLETED
    session.add(simulation)
    session.commit()


def mark_simulation_failed(
    simulation: Simulation,
    error: Exception,
    session: Session,
) -> None:
    """Record the error message and transition the simulation to FAILED."""
    simulation.error_message = str(error)
    simulation.completed_at = datetime.now(timezone.utc)
    simulation.status = SimulationStatus.FAILED
    session.add(simulation)
    session.commit()


def run_household_simulation(simulation_id: UUID, session: Session) -> None:
    """Run one household simulation end-to-end and persist the outcome.

    Failures are recorded on the simulation row and NOT re-raised — unlike
    _run_simulation_in_session, which propagates so report orchestration can
    mark the whole report failed.
    """
    simulation = _load_simulation(simulation_id, session)
    household = _load_household(simulation.household_id, session)
    policy_data = _load_policy_data(simulation.policy_id, session)

    mark_simulation_running(simulation, session)

    try:
        calculate = get_calculator(household.tax_benefit_model_name)
        result = calculate(household.household_data, household.year, policy_data)
        mark_simulation_completed(simulation, result, session)
    except Exception as exc:
        mark_simulation_failed(simulation, exc, session)


def _load_simulation(simulation_id: UUID, session: Session) -> Simulation:
    """Fetch a simulation by primary key; ValueError when absent."""
    if (simulation := session.get(Simulation, simulation_id)) is None:
        raise ValueError(f"Simulation {simulation_id} not found")
    return simulation


def _load_household(household_id: UUID | None, session: Session) -> Household:
    """Fetch a household by primary key; ValueError when unset or absent."""
    if not household_id:
        raise ValueError("Simulation has no household_id")

    if (household := session.get(Household, household_id)) is None:
        raise ValueError(f"Household {household_id} not found")
    return household


def _load_policy_data(policy_id: UUID | None, session: Session) -> dict | None:
    """Load the policy (when a policy_id is set) and convert to calculation format."""
    if not policy_id:
        return None
    return _extract_policy_data(session.get(Policy, policy_id))


# =============================================================================
# Report Orchestration (Async)
# =============================================================================


def _run_local_household_impact(report_id: str, session: Session) -> None:
    """Run household impact analysis locally.

    NOTE: This runs synchronously and blocks the HTTP request when running
    locally (agent_use_modal=False). This mirrors the economic impact behavior.
    True async execution requires Modal.
    """
    report = session.get(Report, report_id)
    if report is None:
        return

    report.status = ReportStatus.RUNNING
    session.add(report)
    session.commit()

    try:
        # Baseline first, then reform (when present); each run commits its
        # own status transitions.
        for simulation_id in (
            report.baseline_simulation_id,
            report.reform_simulation_id,
        ):
            if simulation_id:
                _run_simulation_in_session(simulation_id, session)
        report.status = ReportStatus.COMPLETED
    except Exception as exc:
        report.status = ReportStatus.FAILED
        report.error_message = str(exc)
    session.add(report)
    session.commit()


def _run_simulation_in_session(simulation_id: UUID, session: Session) -> None:
    """Run one pending household simulation within an existing session.

    Silently returns when the simulation is missing or not PENDING. Failures
    are recorded on the row and then re-raised so the caller can fail the
    enclosing report.
    """
    simulation = session.get(Simulation, simulation_id)
    if simulation is None or simulation.status != SimulationStatus.PENDING:
        return

    household = session.get(Household, simulation.household_id)
    if household is None:
        raise ValueError(f"Household {simulation.household_id} not found")

    policy_data = _load_policy_data(simulation.policy_id, session)

    simulation.status = SimulationStatus.RUNNING
    simulation.started_at = datetime.now(timezone.utc)
    session.add(simulation)
    session.commit()

    try:
        calculate = get_calculator(household.tax_benefit_model_name)
        result = calculate(household.household_data, household.year, policy_data)

        simulation.household_result = result
        simulation.status = SimulationStatus.COMPLETED
        simulation.completed_at = datetime.now(timezone.utc)
        session.add(simulation)
        session.commit()
    except Exception as exc:
        simulation.status = SimulationStatus.FAILED
        simulation.error_message = str(exc)
        simulation.completed_at = datetime.now(timezone.utc)
        session.add(simulation)
        session.commit()
        raise


def _trigger_household_impact(
    report_id: str, tax_benefit_model_name: str, session: Session | None = None
) -> None:
    """Dispatch household impact work locally or to Modal based on settings.

    NOTE(review): when agent_use_modal is False but no session is supplied,
    this still falls through to the Modal branch — confirm that fallback is
    intended.
    """
    from policyengine_api.config import settings

    traceparent = get_traceparent()

    if not settings.agent_use_modal and session is not None:
        # Run locally (blocking - see _run_local_household_impact docstring)
        _run_local_household_impact(report_id, session)
        return

    import modal

    suffix = "uk" if tax_benefit_model_name == "policyengine_uk" else "us"
    fn = modal.Function.from_name("policyengine", f"household_impact_{suffix}")
    fn.spawn(report_id=report_id, traceparent=traceparent)


# Legacy functions kept for compatibility
def _load_report(report_id: UUID, session: Session) -> Report:
    """Fetch a report by primary key; ValueError when absent."""
    if (report := session.get(Report, report_id)) is None:
        raise ValueError(f"Report {report_id} not found")
    return report


# =============================================================================
# Request/Response Schemas
# =============================================================================
class HouseholdImpactRequest(BaseModel):
    """Request for household impact analysis."""

    household_id: UUID = Field(description="ID of the household to analyze")
    policy_id: UUID | None = Field(
        default=None,
        description="Reform policy ID. If None, runs single calculation under current law.",
    )
    dynamic_id: UUID | None = Field(
        default=None,
        description="Optional behavioural response specification ID",
    )


class HouseholdSimulationInfo(BaseModel):
    """Info about a household simulation."""

    id: UUID
    status: SimulationStatus
    error_message: str | None = None


class HouseholdImpactResponse(BaseModel):
    """Response for household impact analysis."""

    report_id: UUID
    report_type: str
    status: ReportStatus
    baseline_simulation: HouseholdSimulationInfo | None = None
    reform_simulation: HouseholdSimulationInfo | None = None
    baseline_result: dict | None = None
    reform_result: dict | None = None
    impact: dict | None = None
    error_message: str | None = None


# =============================================================================
# Response Building
# =============================================================================


def build_simulation_info(
    simulation: Simulation | None,
) -> HouseholdSimulationInfo | None:
    """Project a Simulation row onto the lightweight info schema (None-safe)."""
    if simulation is None:
        return None

    return HouseholdSimulationInfo(
        id=simulation.id,
        status=simulation.status,
        error_message=simulation.error_message,
    )


def build_household_response(
    report: Report,
    baseline_sim: Simulation,
    reform_sim: Simulation | None,
    session: Session,
) -> HouseholdImpactResponse:
    """Assemble the API response, computing the impact diff for comparisons."""
    baseline_result = baseline_sim.household_result
    reform_result = reform_sim.household_result if reform_sim else None

    return HouseholdImpactResponse(
        report_id=report.id,
        report_type=report.report_type or "household_single",
        status=report.status,
        baseline_simulation=build_simulation_info(baseline_sim),
        reform_simulation=build_simulation_info(reform_sim),
        baseline_result=baseline_result,
        reform_result=reform_result,
        impact=_compute_impact_if_comparison(
            baseline_sim, reform_sim, baseline_result, reform_result, session
        ),
        error_message=report.error_message,
    )


def _compute_impact_if_comparison(
    baseline_sim: Simulation,
    reform_sim: Simulation | None,
    baseline_result: dict | None,
    reform_result: dict | None,
    session: Session,
) -> dict | None:
    """Compute the impact diff only for comparisons where both results exist."""
    if reform_sim is None:
        return None
    if not baseline_result or not reform_result:
        return None

    household = session.get(Household, baseline_sim.household_id)
    if household is None:
        return None

    config = get_country_config(household.tax_benefit_model_name)
    return compute_household_impact(baseline_result, reform_result, config)


# =============================================================================
# Validation Helpers
# =============================================================================


def validate_household_exists(household_id: UUID, session: Session) -> Household:
    """Return the household or raise a 404 HTTPException."""
    household = session.get(Household, household_id)
    if household is None:
        raise HTTPException(
            status_code=404,
            detail=f"Household {household_id} not found",
        )
    return household


def validate_policy_exists(policy_id: UUID | None, session: Session) -> None:
    """Raise a 404 HTTPException when a provided policy_id does not exist."""
    if not policy_id:
        return

    if session.get(Policy, policy_id) is None:
        raise HTTPException(
            status_code=404,
            detail=f"Policy {policy_id} not found",
        )


# =============================================================================
# Endpoints
# =============================================================================
============================================================================= + + +@router.post("/household-impact", response_model=HouseholdImpactResponse) +def household_impact( + request: HouseholdImpactRequest, + session: Session = Depends(get_session), +) -> HouseholdImpactResponse: + """Run household impact analysis. + + If policy_id is None: single run under current law. + If policy_id is set: comparison (baseline vs reform). + + This is an async operation. The endpoint returns immediately with a report_id + and status="pending". Poll GET /analysis/household-impact/{report_id} until + status="completed" to get results. + """ + household = validate_household_exists(request.household_id, session) + validate_policy_exists(request.policy_id, session) + + model_version = _get_model_version(household.tax_benefit_model_name, session) + + baseline_sim = _create_baseline_simulation( + household, model_version.id, request.dynamic_id, session + ) + reform_sim = _create_reform_simulation( + household, model_version.id, request.policy_id, request.dynamic_id, session + ) + + report_type = "household_comparison" if request.policy_id else "household_single" + report = _get_or_create_report( + baseline_sim_id=baseline_sim.id, + reform_sim_id=reform_sim.id if reform_sim else None, + label=f"Household impact: {household.tax_benefit_model_name}", + report_type=report_type, + session=session, + ) + + if report.status == ReportStatus.PENDING: + with logfire.span("trigger_household_impact", job_id=str(report.id)): + _trigger_household_impact( + str(report.id), household.tax_benefit_model_name, session + ) + + return build_household_response(report, baseline_sim, reform_sim, session) + + +@router.get("/household-impact/{report_id}", response_model=HouseholdImpactResponse) +def get_household_impact( + report_id: UUID, + session: Session = Depends(get_session), +) -> HouseholdImpactResponse: + """Get household impact analysis status and results.""" + report = session.get(Report, 
report_id) + if not report: + raise HTTPException(status_code=404, detail=f"Report {report_id} not found") + + if not report.baseline_simulation_id: + raise HTTPException( + status_code=500, + detail="Report missing baseline simulation ID", + ) + + baseline_sim = session.get(Simulation, report.baseline_simulation_id) + if not baseline_sim: + raise HTTPException(status_code=500, detail="Baseline simulation data missing") + + reform_sim = None + if report.reform_simulation_id: + reform_sim = session.get(Simulation, report.reform_simulation_id) + + return build_household_response(report, baseline_sim, reform_sim, session) + + +# ============================================================================= +# Simulation Creation Helpers +# ============================================================================= + + +def _create_baseline_simulation( + household: Household, + model_version_id: UUID, + dynamic_id: UUID | None, + session: Session, +) -> Simulation: + """Create baseline simulation (current law, no policy).""" + return _get_or_create_simulation( + simulation_type=SimulationType.HOUSEHOLD, + model_version_id=model_version_id, + policy_id=None, + dynamic_id=dynamic_id, + session=session, + household_id=household.id, + ) + + +def _create_reform_simulation( + household: Household, + model_version_id: UUID, + policy_id: UUID | None, + dynamic_id: UUID | None, + session: Session, +) -> Simulation | None: + """Create reform simulation if policy_id is provided.""" + if not policy_id: + return None + + return _get_or_create_simulation( + simulation_type=SimulationType.HOUSEHOLD, + model_version_id=model_version_id, + policy_id=policy_id, + dynamic_id=dynamic_id, + session=session, + household_id=household.id, + ) diff --git a/src/policyengine_api/api/households.py b/src/policyengine_api/api/households.py new file mode 100644 index 0000000..fdee1f7 --- /dev/null +++ b/src/policyengine_api/api/households.py @@ -0,0 +1,119 @@ +"""Stored household CRUD endpoints. 
+ +Households represent saved household definitions that can be reused across +calculations and impact analyses. Create a household once, then reference +it by ID for repeated simulations. + +These endpoints manage stored household *definitions* (people, entity groups, +model name, year). For running calculations on a household, use the +/household/calculate and /household/impact endpoints instead. +""" + +from typing import Any +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session, select + +from policyengine_api.models import Household, HouseholdCreate, HouseholdRead +from policyengine_api.services.database import get_session + +router = APIRouter(prefix="/households", tags=["households"]) + +_ENTITY_GROUP_KEYS = ( + "tax_unit", + "family", + "spm_unit", + "marital_unit", + "household", + "benunit", +) + + +def _pack_household_data(body: HouseholdCreate) -> dict[str, Any]: + """Pack the flat request fields into a single JSON blob for storage.""" + data: dict[str, Any] = {"people": body.people} + for key in _ENTITY_GROUP_KEYS: + val = getattr(body, key) + if val is not None: + data[key] = val + return data + + +def _to_read(record: Household) -> HouseholdRead: + """Unpack the JSON blob back into the flat response shape.""" + data = record.household_data + return HouseholdRead( + id=record.id, + tax_benefit_model_name=record.tax_benefit_model_name, + year=record.year, + label=record.label, + people=data["people"], + tax_unit=data.get("tax_unit"), + family=data.get("family"), + spm_unit=data.get("spm_unit"), + marital_unit=data.get("marital_unit"), + household=data.get("household"), + benunit=data.get("benunit"), + created_at=record.created_at, + updated_at=record.updated_at, + ) + + +@router.post("/", response_model=HouseholdRead, status_code=201) +def create_household(body: HouseholdCreate, session: Session = Depends(get_session)): + """Create a stored household definition. 
+ + The household data (people + entity groups) is persisted so it can be + retrieved later by ID. Use the returned ID with /household/calculate + or /household/impact to run simulations. + """ + record = Household( + tax_benefit_model_name=body.tax_benefit_model_name, + year=body.year, + label=body.label, + household_data=_pack_household_data(body), + ) + session.add(record) + session.commit() + session.refresh(record) + return _to_read(record) + + +@router.get("/", response_model=list[HouseholdRead]) +def list_households( + tax_benefit_model_name: str | None = None, + limit: int = Query(default=50, le=200), + offset: int = Query(default=0, ge=0), + session: Session = Depends(get_session), +): + """List stored households with optional filtering.""" + query = select(Household) + if tax_benefit_model_name is not None: + query = query.where(Household.tax_benefit_model_name == tax_benefit_model_name) + query = query.offset(offset).limit(limit) + records = session.exec(query).all() + return [_to_read(r) for r in records] + + +@router.get("/{household_id}", response_model=HouseholdRead) +def get_household(household_id: UUID, session: Session = Depends(get_session)): + """Get a stored household by ID.""" + record = session.get(Household, household_id) + if not record: + raise HTTPException( + status_code=404, detail=f"Household {household_id} not found" + ) + return _to_read(record) + + +@router.delete("/{household_id}", status_code=204) +def delete_household(household_id: UUID, session: Session = Depends(get_session)): + """Delete a stored household.""" + record = session.get(Household, household_id) + if not record: + raise HTTPException( + status_code=404, detail=f"Household {household_id} not found" + ) + session.delete(record) + session.commit() diff --git a/src/policyengine_api/api/user_household_associations.py b/src/policyengine_api/api/user_household_associations.py new file mode 100644 index 0000000..fa40e06 --- /dev/null +++ 
b/src/policyengine_api/api/user_household_associations.py @@ -0,0 +1,125 @@ +"""User-household association endpoints. + +Associations link a user to a stored household definition with metadata +(label, country). A user can have multiple associations to the same +household (e.g. different labels or configurations). +""" + +from datetime import datetime, timezone +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session, select + +from policyengine_api.models import ( + Household, + UserHouseholdAssociation, + UserHouseholdAssociationCreate, + UserHouseholdAssociationRead, + UserHouseholdAssociationUpdate, +) +from policyengine_api.services.database import get_session + +router = APIRouter( + prefix="/user-household-associations", + tags=["user-household-associations"], +) + + +@router.post("/", response_model=UserHouseholdAssociationRead, status_code=201) +def create_association( + body: UserHouseholdAssociationCreate, + session: Session = Depends(get_session), +): + """Create a user-household association.""" + household = session.get(Household, body.household_id) + if not household: + raise HTTPException( + status_code=404, + detail=f"Household {body.household_id} not found", + ) + + record = UserHouseholdAssociation( + user_id=body.user_id, + household_id=body.household_id, + country_id=body.country_id, + label=body.label, + ) + session.add(record) + session.commit() + session.refresh(record) + return record + + +@router.get("/user/{user_id}", response_model=list[UserHouseholdAssociationRead]) +def list_by_user( + user_id: UUID, + country_id: str | None = None, + limit: int = Query(default=50, le=200), + offset: int = Query(default=0, ge=0), + session: Session = Depends(get_session), +): + """List all associations for a user, optionally filtered by country.""" + query = select(UserHouseholdAssociation).where( + UserHouseholdAssociation.user_id == user_id + ) + if country_id is not None: + query = 
query.where(UserHouseholdAssociation.country_id == country_id) + query = query.offset(offset).limit(limit) + return session.exec(query).all() + + +@router.get( + "/{user_id}/{household_id}", + response_model=list[UserHouseholdAssociationRead], +) +def list_by_user_and_household( + user_id: UUID, + household_id: UUID, + session: Session = Depends(get_session), +): + """List all associations for a specific user+household pair.""" + query = select(UserHouseholdAssociation).where( + UserHouseholdAssociation.user_id == user_id, + UserHouseholdAssociation.household_id == household_id, + ) + return session.exec(query).all() + + +@router.put("/{association_id}", response_model=UserHouseholdAssociationRead) +def update_association( + association_id: UUID, + body: UserHouseholdAssociationUpdate, + session: Session = Depends(get_session), +): + """Update a user-household association (label).""" + record = session.get(UserHouseholdAssociation, association_id) + if not record: + raise HTTPException( + status_code=404, + detail=f"Association {association_id} not found", + ) + update_data = body.model_dump(exclude_unset=True) + for key, value in update_data.items(): + setattr(record, key, value) + record.updated_at = datetime.now(timezone.utc) + session.add(record) + session.commit() + session.refresh(record) + return record + + +@router.delete("/{association_id}", status_code=204) +def delete_association( + association_id: UUID, + session: Session = Depends(get_session), +): + """Delete a user-household association.""" + record = session.get(UserHouseholdAssociation, association_id) + if not record: + raise HTTPException( + status_code=404, + detail=f"Association {association_id} not found", + ) + session.delete(record) + session.commit() diff --git a/src/policyengine_api/config/settings.py b/src/policyengine_api/config/settings.py index 76a1ab1..efba345 100644 --- a/src/policyengine_api/config/settings.py +++ b/src/policyengine_api/config/settings.py @@ -40,10 +40,21 @@ class 
Settings(BaseSettings): @property def database_url(self) -> str: - """Get database URL from Supabase.""" + """Get database URL from Supabase. + + For local development, the database runs on port 54322 (not 54321 which is the API). + Use supabase_db_url to override, or rely on the default local URL. + """ + if self.supabase_db_url: + return self.supabase_db_url + + # For local development, default to the standard Supabase local DB port + if "localhost" in self.supabase_url or "127.0.0.1" in self.supabase_url: + return "postgresql://postgres:postgres@127.0.0.1:54322/postgres" + + # For remote Supabase, construct URL from API URL (usually need supabase_db_url set) return ( - self.supabase_db_url - or self.supabase_url.replace( + self.supabase_url.replace( "http://", "postgresql://postgres:postgres@" ).replace("https://", "postgresql://postgres:postgres@") + "/postgres" diff --git a/src/policyengine_api/modal_app.py b/src/policyengine_api/modal_app.py index 1aa8119..2b486f3 100644 --- a/src/policyengine_api/modal_app.py +++ b/src/policyengine_api/modal_app.py @@ -7,7 +7,8 @@ Function naming follows the API hierarchy: - simulate_household_*: Single household calculation (/simulate/household) - simulate_economy_*: Single economy simulation (/simulate/economy) -- economy_comparison_*: Full economy comparison analysis (/analysis/compare/economy) +- economy_comparison_*: Full economy comparison analysis (/analysis/economic-impact) +- household_impact_*: Household impact analysis (/analysis/household-impact) Deploy with: modal deploy src/policyengine_api/modal_app.py """ @@ -806,7 +807,6 @@ def simulate_economy_uk(simulation_id: str, traceparent: str | None = None) -> N raise ValueError(f"Dataset {simulation.dataset_id} not found") # Import policyengine - from policyengine.core import Simulation as PESimulation from policyengine.tax_benefit_models.uk import uk_latest from policyengine.tax_benefit_models.uk.datasets import ( PolicyEngineUKDataset, @@ -814,7 +814,7 @@ def 
simulate_economy_uk(simulation_id: str, traceparent: str | None = None) -> N pe_model_version = uk_latest - # Get policy and dynamic + # Get policy and dynamic as PEPolicy/PEDynamic objects policy = _get_pe_policy_uk( simulation.policy_id, pe_model_version, session ) @@ -822,6 +822,13 @@ def simulate_economy_uk(simulation_id: str, traceparent: str | None = None) -> N simulation.dynamic_id, pe_model_version, session ) + # Convert to reform dict format for Microsimulation + # This is necessary because policyengine.core.Simulation applies + # reforms AFTER creating Microsimulation, which doesn't work + policy_reform = _pe_policy_to_reform_dict(policy) + dynamic_reform = _pe_policy_to_reform_dict(dynamic) + reform = _merge_reform_dicts(policy_reform, dynamic_reform) + # Download dataset local_path = download_dataset( dataset.filepath, supabase_url, supabase_key, storage_bucket @@ -834,15 +841,12 @@ def simulate_economy_uk(simulation_id: str, traceparent: str | None = None) -> N year=dataset.year, ) - # Create and run simulation + # Run simulation using Microsimulation directly with reform + # This ensures reforms are applied at construction time with logfire.span("run_simulation"): - pe_sim = PESimulation( - dataset=pe_dataset, - tax_benefit_model_version=pe_model_version, - policy=policy, - dynamic=dynamic, + pe_output_dataset = _run_uk_economy_simulation( + pe_dataset, reform, pe_model_version, simulation_id ) - pe_sim.ensure() # Save output dataset with logfire.span("save_output_dataset"): @@ -852,8 +856,8 @@ def simulate_economy_uk(simulation_id: str, traceparent: str | None = None) -> N output_path = f"/tmp/{output_filename}" # Set filepath and save - pe_sim.output_dataset.filepath = output_path - pe_sim.output_dataset.save() + pe_output_dataset.filepath = output_path + pe_output_dataset.save() # Upload to Supabase storage supabase = create_client(supabase_url, supabase_key) @@ -868,7 +872,7 @@ def simulate_economy_uk(simulation_id: str, traceparent: str | None = 
None) -> N ) # Create output dataset record - output_dataset = Dataset( + output_dataset_record = Dataset( name=f"Output: {dataset.name}", description=f"Output from simulation {simulation_id}", filepath=output_filename, @@ -876,12 +880,12 @@ def simulate_economy_uk(simulation_id: str, traceparent: str | None = None) -> N is_output_dataset=True, tax_benefit_model_id=dataset.tax_benefit_model_id, ) - session.add(output_dataset) + session.add(output_dataset_record) session.commit() - session.refresh(output_dataset) + session.refresh(output_dataset_record) # Link to simulation - simulation.output_dataset_id = output_dataset.id + simulation.output_dataset_id = output_dataset_record.id # Mark as completed simulation.status = SimulationStatus.COMPLETED @@ -972,15 +976,15 @@ def simulate_economy_us(simulation_id: str, traceparent: str | None = None) -> N raise ValueError(f"Dataset {simulation.dataset_id} not found") # Import policyengine - from policyengine.core import Simulation as PESimulation from policyengine.tax_benefit_models.us import us_latest from policyengine.tax_benefit_models.us.datasets import ( PolicyEngineUSDataset, + USYearData, ) pe_model_version = us_latest - # Get policy and dynamic + # Get policy and dynamic as PEPolicy/PEDynamic objects policy = _get_pe_policy_us( simulation.policy_id, pe_model_version, session ) @@ -988,6 +992,13 @@ def simulate_economy_us(simulation_id: str, traceparent: str | None = None) -> N simulation.dynamic_id, pe_model_version, session ) + # Convert to reform dict format for Microsimulation + # This is necessary because policyengine.core.Simulation applies + # reforms AFTER creating Microsimulation, which doesn't work + policy_reform = _pe_policy_to_reform_dict(policy) + dynamic_reform = _pe_policy_to_reform_dict(dynamic) + reform = _merge_reform_dicts(policy_reform, dynamic_reform) + # Download dataset local_path = download_dataset( dataset.filepath, supabase_url, supabase_key, storage_bucket @@ -1000,15 +1011,12 @@ def 
simulate_economy_us(simulation_id: str, traceparent: str | None = None) -> N year=dataset.year, ) - # Create and run simulation + # Run simulation using Microsimulation directly with reform + # This ensures reforms are applied at construction time with logfire.span("run_simulation"): - pe_sim = PESimulation( - dataset=pe_dataset, - tax_benefit_model_version=pe_model_version, - policy=policy, - dynamic=dynamic, + pe_output_dataset = _run_us_economy_simulation( + pe_dataset, reform, pe_model_version, simulation_id ) - pe_sim.ensure() # Save output dataset with logfire.span("save_output_dataset"): @@ -1018,8 +1026,8 @@ def simulate_economy_us(simulation_id: str, traceparent: str | None = None) -> N output_path = f"/tmp/{output_filename}" # Set filepath and save - pe_sim.output_dataset.filepath = output_path - pe_sim.output_dataset.save() + pe_output_dataset.filepath = output_path + pe_output_dataset.save() # Upload to Supabase storage supabase = create_client(supabase_url, supabase_key) @@ -1034,7 +1042,7 @@ def simulate_economy_us(simulation_id: str, traceparent: str | None = None) -> N ) # Create output dataset record - output_dataset = Dataset( + output_dataset_record = Dataset( name=f"Output: {dataset.name}", description=f"Output from simulation {simulation_id}", filepath=output_filename, @@ -1042,12 +1050,12 @@ def simulate_economy_us(simulation_id: str, traceparent: str | None = None) -> N is_output_dataset=True, tax_benefit_model_id=dataset.tax_benefit_model_id, ) - session.add(output_dataset) + session.add(output_dataset_record) session.commit() - session.refresh(output_dataset) + session.refresh(output_dataset_record) # Link to simulation - simulation.output_dataset_id = output_dataset.id + simulation.output_dataset_id = output_dataset_record.id # Mark as completed simulation.status = SimulationStatus.COMPLETED @@ -1815,6 +1823,403 @@ def _get_pe_dynamic_us(dynamic_id, model_version, session): return _get_pe_dynamic_uk(dynamic_id, model_version, session) +def 
_pe_policy_to_reform_dict(policy) -> dict | None: + """Convert a policyengine.core.policy.Policy to reform dict format. + + The policyengine-us/uk Microsimulation expects reforms in the format: + {"parameter.name": {"YYYY-MM-DD": value}} + + This is necessary because the policyengine.core.Simulation applies reforms + AFTER creating the Microsimulation, which doesn't work due to caching. + We need to pass the reform at Microsimulation construction time. + """ + if policy is None: + return None + + if not policy.parameter_values: + return None + + reform = {} + for pv in policy.parameter_values: + if not pv.parameter: + continue + param_name = pv.parameter.name + value = pv.value + start_date = pv.start_date + + if param_name and start_date: + # Format date as YYYY-MM-DD string + if hasattr(start_date, "strftime"): + date_str = start_date.strftime("%Y-%m-%d") + else: + date_str = str(start_date).split("T")[0] + + if param_name not in reform: + reform[param_name] = {} + reform[param_name][date_str] = value + + return reform if reform else None + + +def _merge_reform_dicts(reform1: dict | None, reform2: dict | None) -> dict | None: + """Merge two reform dicts, with reform2 taking precedence.""" + if reform1 is None and reform2 is None: + return None + if reform1 is None: + return reform2 + if reform2 is None: + return reform1 + + merged = dict(reform1) + for param_name, dates in reform2.items(): + if param_name not in merged: + merged[param_name] = {} + merged[param_name].update(dates) + return merged + + +def _run_us_economy_simulation(pe_dataset, reform, pe_model_version, simulation_id): + """Run US economy simulation using Microsimulation directly. + + This bypasses policyengine.core.Simulation which has a bug where reforms + are applied AFTER creating Microsimulation (when it's too late). + Instead, we pass the reform dict at Microsimulation construction time. 
+ """ + from pathlib import Path + + import numpy as np + import pandas as pd + from microdf import MicroDataFrame + from policyengine.tax_benefit_models.us.datasets import ( + PolicyEngineUSDataset, + USYearData, + ) + from policyengine_core.simulations.simulation_builder import SimulationBuilder + from policyengine_us import Microsimulation + from policyengine_us.system import system + + # Load dataset + pe_dataset.load() + year = pe_dataset.year + + # Create Microsimulation with reform applied at construction time + microsim = Microsimulation(reform=reform) + + # Build simulation from dataset using SimulationBuilder + person_df = pd.DataFrame(pe_dataset.data.person) + + # Determine column naming convention + household_id_col = ( + "person_household_id" + if "person_household_id" in person_df.columns + else "household_id" + ) + marital_unit_id_col = ( + "person_marital_unit_id" + if "person_marital_unit_id" in person_df.columns + else "marital_unit_id" + ) + family_id_col = ( + "person_family_id" if "person_family_id" in person_df.columns else "family_id" + ) + spm_unit_id_col = ( + "person_spm_unit_id" + if "person_spm_unit_id" in person_df.columns + else "spm_unit_id" + ) + tax_unit_id_col = ( + "person_tax_unit_id" + if "person_tax_unit_id" in person_df.columns + else "tax_unit_id" + ) + + # Declare entities + builder = SimulationBuilder() + builder.populations = system.instantiate_entities() + + builder.declare_person_entity("person", person_df["person_id"].values) + builder.declare_entity("household", np.unique(person_df[household_id_col].values)) + builder.declare_entity("spm_unit", np.unique(person_df[spm_unit_id_col].values)) + builder.declare_entity("family", np.unique(person_df[family_id_col].values)) + builder.declare_entity("tax_unit", np.unique(person_df[tax_unit_id_col].values)) + builder.declare_entity( + "marital_unit", np.unique(person_df[marital_unit_id_col].values) + ) + + # Join persons to entities + builder.join_with_persons( + 
builder.populations["household"], + person_df[household_id_col].values, + np.array(["member"] * len(person_df)), + ) + builder.join_with_persons( + builder.populations["spm_unit"], + person_df[spm_unit_id_col].values, + np.array(["member"] * len(person_df)), + ) + builder.join_with_persons( + builder.populations["family"], + person_df[family_id_col].values, + np.array(["member"] * len(person_df)), + ) + builder.join_with_persons( + builder.populations["tax_unit"], + person_df[tax_unit_id_col].values, + np.array(["member"] * len(person_df)), + ) + builder.join_with_persons( + builder.populations["marital_unit"], + person_df[marital_unit_id_col].values, + np.array(["member"] * len(person_df)), + ) + + microsim.build_from_populations(builder.populations) + + # Set input variables + id_columns = { + "person_id", + "household_id", + "person_household_id", + "spm_unit_id", + "person_spm_unit_id", + "family_id", + "person_family_id", + "tax_unit_id", + "person_tax_unit_id", + "marital_unit_id", + "person_marital_unit_id", + } + + for entity_name, entity_data in [ + ("person", pe_dataset.data.person), + ("household", pe_dataset.data.household), + ("spm_unit", pe_dataset.data.spm_unit), + ("family", pe_dataset.data.family), + ("tax_unit", pe_dataset.data.tax_unit), + ("marital_unit", pe_dataset.data.marital_unit), + ]: + df = pd.DataFrame(entity_data) + for column in df.columns: + if column not in id_columns and column in system.variables: + microsim.set_input(column, year, df[column].values) + + # Calculate output variables and build output dataset + data = { + "person": pd.DataFrame(), + "marital_unit": pd.DataFrame(), + "family": pd.DataFrame(), + "spm_unit": pd.DataFrame(), + "tax_unit": pd.DataFrame(), + "household": pd.DataFrame(), + } + + weight_columns = { + "person_weight", + "household_weight", + "marital_unit_weight", + "family_weight", + "spm_unit_weight", + "tax_unit_weight", + } + + # Copy ID and weight columns from input dataset + for entity in data.keys(): + 
input_df = pd.DataFrame(getattr(pe_dataset.data, entity)) + entity_id_col = f"{entity}_id" + entity_weight_col = f"{entity}_weight" + + if entity_id_col in input_df.columns: + data[entity][entity_id_col] = input_df[entity_id_col].values + if entity_weight_col in input_df.columns: + data[entity][entity_weight_col] = input_df[entity_weight_col].values + + # Copy person-level group ID columns + for col in person_df.columns: + if col.startswith("person_") and col.endswith("_id"): + target_col = col.replace("person_", "") + if target_col in id_columns: + data["person"][target_col] = person_df[col].values + + # Calculate non-ID, non-weight variables + for entity, variables in pe_model_version.entity_variables.items(): + for var in variables: + if var not in id_columns and var not in weight_columns: + data[entity][var] = microsim.calculate( + var, period=year, map_to=entity + ).values + + # Convert to MicroDataFrames + data["person"] = MicroDataFrame(data["person"], weights="person_weight") + data["marital_unit"] = MicroDataFrame( + data["marital_unit"], weights="marital_unit_weight" + ) + data["family"] = MicroDataFrame(data["family"], weights="family_weight") + data["spm_unit"] = MicroDataFrame(data["spm_unit"], weights="spm_unit_weight") + data["tax_unit"] = MicroDataFrame(data["tax_unit"], weights="tax_unit_weight") + data["household"] = MicroDataFrame(data["household"], weights="household_weight") + + # Create output dataset + return PolicyEngineUSDataset( + id=simulation_id, + name=pe_dataset.name, + description=pe_dataset.description, + filepath=str(Path(pe_dataset.filepath).parent / (simulation_id + ".h5")), + year=year, + is_output_dataset=True, + data=USYearData( + person=data["person"], + marital_unit=data["marital_unit"], + family=data["family"], + spm_unit=data["spm_unit"], + tax_unit=data["tax_unit"], + household=data["household"], + ), + ) + + +def _run_uk_economy_simulation(pe_dataset, reform, pe_model_version, simulation_id): + """Run UK economy 
simulation using Microsimulation directly. + + This bypasses policyengine.core.Simulation which has a bug where reforms + are applied AFTER creating Microsimulation (when it's too late). + Instead, we pass the reform dict at Microsimulation construction time. + """ + from pathlib import Path + + import numpy as np + import pandas as pd + from microdf import MicroDataFrame + from policyengine.tax_benefit_models.uk.datasets import ( + PolicyEngineUKDataset, + UKYearData, + ) + from policyengine_core.simulations.simulation_builder import SimulationBuilder + from policyengine_uk import Microsimulation + from policyengine_uk.system import system + + # Load dataset + pe_dataset.load() + year = pe_dataset.year + + # Create Microsimulation with reform applied at construction time + microsim = Microsimulation(reform=reform) + + # Build simulation from dataset using SimulationBuilder + person_df = pd.DataFrame(pe_dataset.data.person) + + # Determine column naming convention + benunit_id_col = ( + "person_benunit_id" + if "person_benunit_id" in person_df.columns + else "benunit_id" + ) + household_id_col = ( + "person_household_id" + if "person_household_id" in person_df.columns + else "household_id" + ) + + # Declare entities + builder = SimulationBuilder() + builder.populations = system.instantiate_entities() + + builder.declare_person_entity("person", person_df["person_id"].values) + builder.declare_entity("benunit", np.unique(person_df[benunit_id_col].values)) + builder.declare_entity("household", np.unique(person_df[household_id_col].values)) + + # Join persons to entities + builder.join_with_persons( + builder.populations["benunit"], + person_df[benunit_id_col].values, + np.array(["member"] * len(person_df)), + ) + builder.join_with_persons( + builder.populations["household"], + person_df[household_id_col].values, + np.array(["member"] * len(person_df)), + ) + + microsim.build_from_populations(builder.populations) + + # Set input variables + id_columns = { + 
"person_id", + "benunit_id", + "person_benunit_id", + "household_id", + "person_household_id", + } + + for entity_name, entity_data in [ + ("person", pe_dataset.data.person), + ("benunit", pe_dataset.data.benunit), + ("household", pe_dataset.data.household), + ]: + df = pd.DataFrame(entity_data) + for column in df.columns: + if column not in id_columns and column in system.variables: + microsim.set_input(column, year, df[column].values) + + # Calculate output variables and build output dataset + data = { + "person": pd.DataFrame(), + "benunit": pd.DataFrame(), + "household": pd.DataFrame(), + } + + weight_columns = { + "person_weight", + "benunit_weight", + "household_weight", + } + + # Copy ID and weight columns from input dataset + for entity in data.keys(): + input_df = pd.DataFrame(getattr(pe_dataset.data, entity)) + entity_id_col = f"{entity}_id" + entity_weight_col = f"{entity}_weight" + + if entity_id_col in input_df.columns: + data[entity][entity_id_col] = input_df[entity_id_col].values + if entity_weight_col in input_df.columns: + data[entity][entity_weight_col] = input_df[entity_weight_col].values + + # Copy person-level group ID columns + for col in person_df.columns: + if col.startswith("person_") and col.endswith("_id"): + target_col = col.replace("person_", "") + if target_col in id_columns: + data["person"][target_col] = person_df[col].values + + # Calculate non-ID, non-weight variables + for entity, variables in pe_model_version.entity_variables.items(): + for var in variables: + if var not in id_columns and var not in weight_columns: + data[entity][var] = microsim.calculate( + var, period=year, map_to=entity + ).values + + # Convert to MicroDataFrames + data["person"] = MicroDataFrame(data["person"], weights="person_weight") + data["benunit"] = MicroDataFrame(data["benunit"], weights="benunit_weight") + data["household"] = MicroDataFrame(data["household"], weights="household_weight") + + # Create output dataset + return PolicyEngineUKDataset( + 
id=simulation_id, + name=pe_dataset.name, + description=pe_dataset.description, + filepath=str(Path(pe_dataset.filepath).parent / (simulation_id + ".h5")), + year=year, + is_output_dataset=True, + data=UKYearData( + person=data["person"], + benunit=data["benunit"], + household=data["household"], + ), + ) + + @app.function( image=uk_image, secrets=[db_secrets, logfire_secrets], @@ -2516,3 +2921,689 @@ def compute_change_aggregate_us( raise finally: logfire.force_flush() + + +# ============================================================================= +# Household Impact Functions +# ============================================================================= + + +@app.function( + image=uk_image, + secrets=[db_secrets, logfire_secrets], + memory=2048, + cpu=2, + timeout=300, +) +def household_impact_uk(report_id: str, traceparent: str | None = None) -> None: + """Run UK household impact analysis and write results to database.""" + import logfire + + configure_logfire("policyengine-modal-uk", traceparent) + + try: + with logfire.span("household_impact_uk", report_id=report_id): + from datetime import datetime, timezone + from uuid import UUID + + from sqlmodel import Session, create_engine + + database_url = get_database_url() + engine = create_engine(database_url) + + try: + from policyengine_api.models import ( + Household, + Report, + ReportStatus, + Simulation, + SimulationStatus, + ) + + with Session(engine) as session: + # Load report + report = session.get(Report, UUID(report_id)) + if not report: + raise ValueError(f"Report {report_id} not found") + + # Mark as running + report.status = ReportStatus.RUNNING + session.add(report) + session.commit() + + # Run baseline simulation + if report.baseline_simulation_id: + _run_household_simulation_uk( + report.baseline_simulation_id, session + ) + + # Run reform simulation if present + if report.reform_simulation_id: + _run_household_simulation_uk( + report.reform_simulation_id, session + ) + + # Mark report as 
completed + report.status = ReportStatus.COMPLETED + session.add(report) + session.commit() + + except Exception as e: + logfire.error( + "UK household impact failed", report_id=report_id, error=str(e) + ) + try: + from sqlmodel import text + + with Session(engine) as session: + session.execute( + text( + "UPDATE reports SET status = 'FAILED', error_message = :error " + "WHERE id = :report_id" + ), + {"report_id": report_id, "error": str(e)[:1000]}, + ) + session.commit() + except Exception as db_error: + logfire.error("Failed to update DB", error=str(db_error)) + raise + finally: + logfire.force_flush() + + +def _run_household_simulation_uk(simulation_id, session) -> None: + """Run a single UK household simulation.""" + from datetime import datetime, timezone + + from policyengine_api.models import ( + Household, + Simulation, + SimulationStatus, + ) + + simulation = session.get(Simulation, simulation_id) + if not simulation or simulation.status != SimulationStatus.PENDING: + return + + household = session.get(Household, simulation.household_id) + if not household: + raise ValueError(f"Household {simulation.household_id} not found") + + # Mark as running + simulation.status = SimulationStatus.RUNNING + simulation.started_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + + try: + # Get policy data if present + policy_data = _get_household_policy_data(simulation.policy_id, session) + + # Run calculation + result = _calculate_uk_household( + household.household_data, + household.year, + policy_data, + ) + + # Store result + simulation.household_result = result + simulation.status = SimulationStatus.COMPLETED + simulation.completed_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + except Exception as e: + simulation.status = SimulationStatus.FAILED + simulation.error_message = str(e) + simulation.completed_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + raise + + +def 
_calculate_uk_household( + household_data: dict, year: int, policy_data: dict | None +) -> dict: + """Calculate UK household and return result dict.""" + import tempfile + from pathlib import Path + + import pandas as pd + from microdf import MicroDataFrame + from policyengine.core import Simulation + from policyengine.tax_benefit_models.uk import uk_latest + from policyengine.tax_benefit_models.uk.datasets import ( + PolicyEngineUKDataset, + UKYearData, + ) + + people = household_data.get("people", []) + benunit = household_data.get("benunit", []) + hh = household_data.get("household", []) + + # Ensure lists + if isinstance(benunit, dict): + benunit = [benunit] + if isinstance(hh, dict): + hh = [hh] + + n_people = len(people) + n_benunits = max(1, len(benunit) if benunit else 1) + n_households = max(1, len(hh) if hh else 1) + + # Build person data + person_data = { + "person_id": list(range(n_people)), + "person_benunit_id": [0] * n_people, + "person_household_id": [0] * n_people, + "person_weight": [1.0] * n_people, + } + for i, person in enumerate(people): + for key, value in person.items(): + if key not in person_data: + person_data[key] = [0.0] * n_people + person_data[key][i] = value + + # Build benunit data + benunit_data = { + "benunit_id": list(range(n_benunits)), + "benunit_weight": [1.0] * n_benunits, + } + for i, bu in enumerate(benunit if benunit else [{}]): + for key, value in bu.items(): + if key not in benunit_data: + benunit_data[key] = [0.0] * n_benunits + benunit_data[key][i] = value + + # Build household data + household_df_data = { + "household_id": list(range(n_households)), + "household_weight": [1.0] * n_households, + "region": ["LONDON"] * n_households, + "tenure_type": ["RENT_PRIVATELY"] * n_households, + "council_tax": [0.0] * n_households, + "rent": [0.0] * n_households, + } + for i, h in enumerate(hh if hh else [{}]): + for key, value in h.items(): + if key not in household_df_data: + household_df_data[key] = [0.0] * n_households + 
household_df_data[key][i] = value + + # Create MicroDataFrames + person_df = MicroDataFrame(pd.DataFrame(person_data), weights="person_weight") + benunit_df = MicroDataFrame(pd.DataFrame(benunit_data), weights="benunit_weight") + household_df = MicroDataFrame( + pd.DataFrame(household_df_data), weights="household_weight" + ) + + # Create temporary dataset + tmpdir = tempfile.mkdtemp() + filepath = str(Path(tmpdir) / "household_calc.h5") + + dataset = PolicyEngineUKDataset( + name="Household calculation", + description="Household(s) for calculation", + person=person_df, + benunit=benunit_df, + household=household_df, + filepath=filepath, + year_data_class=UKYearData, + ) + dataset.save() + + # Build policy if provided + policy = None + if policy_data: + from policyengine.core.policy import ParameterValue, Policy + + pe_param_values = [] + param_lookup = {p.name: p for p in uk_latest.parameters} + for pv in policy_data.get("parameter_values", []): + param_name = pv.get("parameter_name") + if param_name and param_name in param_lookup: + pe_pv = ParameterValue( + parameter=param_lookup[param_name], + value=pv.get("value"), + start_date=pv.get("start_date"), + end_date=pv.get("end_date"), + ) + pe_param_values.append(pe_pv) + + if pe_param_values: + policy = Policy( + name=policy_data.get("name", "Reform"), + description=policy_data.get("description", ""), + parameter_values=pe_param_values, + ) + + # Run simulation + sim = Simulation( + dataset=dataset, + tax_benefit_model_version=uk_latest, + policy=policy, + ) + sim.ensure() + + # Extract results + result = {"person": [], "benunit": [], "household": []} + + for i in range(n_people): + person_result = {} + for var in sim.output_dataset.person.columns: + val = sim.output_dataset.person[var].iloc[i] + person_result[var] = float(val) if hasattr(val, "item") else val + result["person"].append(person_result) + + for i in range(n_benunits): + benunit_result = {} + for var in sim.output_dataset.benunit.columns: + val = 
sim.output_dataset.benunit[var].iloc[i] + benunit_result[var] = float(val) if hasattr(val, "item") else val + result["benunit"].append(benunit_result) + + for i in range(n_households): + household_result = {} + for var in sim.output_dataset.household.columns: + val = sim.output_dataset.household[var].iloc[i] + household_result[var] = float(val) if hasattr(val, "item") else val + result["household"].append(household_result) + + return result + + +@app.function( + image=us_image, + secrets=[db_secrets, logfire_secrets], + memory=2048, + cpu=2, + timeout=300, +) +def household_impact_us(report_id: str, traceparent: str | None = None) -> None: + """Run US household impact analysis and write results to database.""" + import logfire + + configure_logfire("policyengine-modal-us", traceparent) + + try: + with logfire.span("household_impact_us", report_id=report_id): + from datetime import datetime, timezone + from uuid import UUID + + from sqlmodel import Session, create_engine + + database_url = get_database_url() + engine = create_engine(database_url) + + try: + from policyengine_api.models import ( + Household, + Report, + ReportStatus, + Simulation, + SimulationStatus, + ) + + with Session(engine) as session: + # Load report + report = session.get(Report, UUID(report_id)) + if not report: + raise ValueError(f"Report {report_id} not found") + + # Mark as running + report.status = ReportStatus.RUNNING + session.add(report) + session.commit() + + # Run baseline simulation + if report.baseline_simulation_id: + _run_household_simulation_us( + report.baseline_simulation_id, session + ) + + # Run reform simulation if present + if report.reform_simulation_id: + _run_household_simulation_us( + report.reform_simulation_id, session + ) + + # Mark report as completed + report.status = ReportStatus.COMPLETED + session.add(report) + session.commit() + + except Exception as e: + logfire.error( + "US household impact failed", report_id=report_id, error=str(e) + ) + try: + from 
sqlmodel import text + + with Session(engine) as session: + session.execute( + text( + "UPDATE reports SET status = 'FAILED', error_message = :error " + "WHERE id = :report_id" + ), + {"report_id": report_id, "error": str(e)[:1000]}, + ) + session.commit() + except Exception as db_error: + logfire.error("Failed to update DB", error=str(db_error)) + raise + finally: + logfire.force_flush() + + +def _run_household_simulation_us(simulation_id, session) -> None: + """Run a single US household simulation.""" + from datetime import datetime, timezone + + from policyengine_api.models import ( + Household, + Simulation, + SimulationStatus, + ) + + simulation = session.get(Simulation, simulation_id) + if not simulation or simulation.status != SimulationStatus.PENDING: + return + + household = session.get(Household, simulation.household_id) + if not household: + raise ValueError(f"Household {simulation.household_id} not found") + + # Mark as running + simulation.status = SimulationStatus.RUNNING + simulation.started_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + + try: + # Get policy data if present + policy_data = _get_household_policy_data(simulation.policy_id, session) + + # Run calculation + result = _calculate_us_household( + household.household_data, + household.year, + policy_data, + ) + + # Store result + simulation.household_result = result + simulation.status = SimulationStatus.COMPLETED + simulation.completed_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + except Exception as e: + simulation.status = SimulationStatus.FAILED + simulation.error_message = str(e) + simulation.completed_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + raise + + +def _calculate_us_household( + household_data: dict, year: int, policy_data: dict | None +) -> dict: + """Calculate US household and return result dict.""" + import tempfile + from pathlib import Path + + import pandas as pd + from 
microdf import MicroDataFrame + from policyengine.core import Simulation + from policyengine.tax_benefit_models.us import us_latest + from policyengine.tax_benefit_models.us.datasets import ( + PolicyEngineUSDataset, + USYearData, + ) + + people = household_data.get("people", []) + tax_unit = household_data.get("tax_unit", []) + family = household_data.get("family", []) + spm_unit = household_data.get("spm_unit", []) + marital_unit = household_data.get("marital_unit", []) + hh = household_data.get("household", []) + + # Ensure lists + if isinstance(tax_unit, dict): + tax_unit = [tax_unit] + if isinstance(family, dict): + family = [family] + if isinstance(spm_unit, dict): + spm_unit = [spm_unit] + if isinstance(marital_unit, dict): + marital_unit = [marital_unit] + if isinstance(hh, dict): + hh = [hh] + + n_people = len(people) + n_tax_units = max(1, len(tax_unit) if tax_unit else 1) + n_families = max(1, len(family) if family else 1) + n_spm_units = max(1, len(spm_unit) if spm_unit else 1) + n_marital_units = max(1, len(marital_unit) if marital_unit else 1) + n_households = max(1, len(hh) if hh else 1) + + # Build person data + person_data = { + "person_id": list(range(n_people)), + "person_tax_unit_id": [0] * n_people, + "person_family_id": [0] * n_people, + "person_spm_unit_id": [0] * n_people, + "person_marital_unit_id": [0] * n_people, + "person_household_id": [0] * n_people, + "person_weight": [1.0] * n_people, + } + for i, person in enumerate(people): + for key, value in person.items(): + if key not in person_data: + person_data[key] = [0.0] * n_people + person_data[key][i] = value + + # Build tax_unit data + tax_unit_data = { + "tax_unit_id": list(range(n_tax_units)), + "tax_unit_weight": [1.0] * n_tax_units, + } + for i, tu in enumerate(tax_unit if tax_unit else [{}]): + for key, value in tu.items(): + if key not in tax_unit_data: + tax_unit_data[key] = [0.0] * n_tax_units + tax_unit_data[key][i] = value + + # Build family data + family_data = { + 
"family_id": list(range(n_families)), + "family_weight": [1.0] * n_families, + } + for i, fam in enumerate(family if family else [{}]): + for key, value in fam.items(): + if key not in family_data: + family_data[key] = [0.0] * n_families + family_data[key][i] = value + + # Build spm_unit data + spm_unit_data = { + "spm_unit_id": list(range(n_spm_units)), + "spm_unit_weight": [1.0] * n_spm_units, + } + for i, spm in enumerate(spm_unit if spm_unit else [{}]): + for key, value in spm.items(): + if key not in spm_unit_data: + spm_unit_data[key] = [0.0] * n_spm_units + spm_unit_data[key][i] = value + + # Build marital_unit data + marital_unit_data = { + "marital_unit_id": list(range(n_marital_units)), + "marital_unit_weight": [1.0] * n_marital_units, + } + for i, mu in enumerate(marital_unit if marital_unit else [{}]): + for key, value in mu.items(): + if key not in marital_unit_data: + marital_unit_data[key] = [0.0] * n_marital_units + marital_unit_data[key][i] = value + + # Build household data + household_df_data = { + "household_id": list(range(n_households)), + "household_weight": [1.0] * n_households, + } + for i, h in enumerate(hh if hh else [{}]): + for key, value in h.items(): + if key not in household_df_data: + household_df_data[key] = [0.0] * n_households + household_df_data[key][i] = value + + # Create MicroDataFrames + person_df = MicroDataFrame(pd.DataFrame(person_data), weights="person_weight") + tax_unit_df = MicroDataFrame( + pd.DataFrame(tax_unit_data), weights="tax_unit_weight" + ) + family_df = MicroDataFrame(pd.DataFrame(family_data), weights="family_weight") + spm_unit_df = MicroDataFrame( + pd.DataFrame(spm_unit_data), weights="spm_unit_weight" + ) + marital_unit_df = MicroDataFrame( + pd.DataFrame(marital_unit_data), weights="marital_unit_weight" + ) + household_df = MicroDataFrame( + pd.DataFrame(household_df_data), weights="household_weight" + ) + + # Create temporary dataset + tmpdir = tempfile.mkdtemp() + filepath = str(Path(tmpdir) / 
"household_calc.h5") + + dataset = PolicyEngineUSDataset( + name="Household calculation", + description="Household(s) for calculation", + person=person_df, + tax_unit=tax_unit_df, + family=family_df, + spm_unit=spm_unit_df, + marital_unit=marital_unit_df, + household=household_df, + filepath=filepath, + year_data_class=USYearData, + ) + dataset.save() + + # Build policy if provided + policy = None + if policy_data: + from policyengine.core.policy import ParameterValue, Policy + + pe_param_values = [] + param_lookup = {p.name: p for p in us_latest.parameters} + for pv in policy_data.get("parameter_values", []): + param_name = pv.get("parameter_name") + if param_name and param_name in param_lookup: + pe_pv = ParameterValue( + parameter=param_lookup[param_name], + value=pv.get("value"), + start_date=pv.get("start_date"), + end_date=pv.get("end_date"), + ) + pe_param_values.append(pe_pv) + + if pe_param_values: + policy = Policy( + name=policy_data.get("name", "Reform"), + description=policy_data.get("description", ""), + parameter_values=pe_param_values, + ) + + # Run simulation + sim = Simulation( + dataset=dataset, + tax_benefit_model_version=us_latest, + policy=policy, + ) + sim.ensure() + + # Extract results + result = { + "person": [], + "tax_unit": [], + "family": [], + "spm_unit": [], + "marital_unit": [], + "household": [], + } + + for i in range(n_people): + person_result = {} + for var in sim.output_dataset.person.columns: + val = sim.output_dataset.person[var].iloc[i] + person_result[var] = float(val) if hasattr(val, "item") else val + result["person"].append(person_result) + + for i in range(n_tax_units): + tu_result = {} + for var in sim.output_dataset.tax_unit.columns: + val = sim.output_dataset.tax_unit[var].iloc[i] + tu_result[var] = float(val) if hasattr(val, "item") else val + result["tax_unit"].append(tu_result) + + for i in range(n_families): + fam_result = {} + for var in sim.output_dataset.family.columns: + val = 
sim.output_dataset.family[var].iloc[i] + fam_result[var] = float(val) if hasattr(val, "item") else val + result["family"].append(fam_result) + + for i in range(n_spm_units): + spm_result = {} + for var in sim.output_dataset.spm_unit.columns: + val = sim.output_dataset.spm_unit[var].iloc[i] + spm_result[var] = float(val) if hasattr(val, "item") else val + result["spm_unit"].append(spm_result) + + for i in range(n_marital_units): + mu_result = {} + for var in sim.output_dataset.marital_unit.columns: + val = sim.output_dataset.marital_unit[var].iloc[i] + mu_result[var] = float(val) if hasattr(val, "item") else val + result["marital_unit"].append(mu_result) + + for i in range(n_households): + hh_result = {} + for var in sim.output_dataset.household.columns: + val = sim.output_dataset.household[var].iloc[i] + hh_result[var] = float(val) if hasattr(val, "item") else val + result["household"].append(hh_result) + + return result + + +def _get_household_policy_data(policy_id, session) -> dict | None: + """Get policy data for household calculation.""" + if policy_id is None: + return None + + from policyengine_api.models import Policy + + db_policy = session.get(Policy, policy_id) + if not db_policy: + return None + + return { + "name": db_policy.name, + "description": db_policy.description, + "parameter_values": [ + { + "parameter_name": pv.parameter.name if pv.parameter else None, + "value": pv.value_json.get("value") + if isinstance(pv.value_json, dict) + else pv.value_json, + "start_date": pv.start_date.isoformat() if pv.start_date else None, + "end_date": pv.end_date.isoformat() if pv.end_date else None, + } + for pv in db_policy.parameter_values + if pv.parameter + ], + } diff --git a/src/policyengine_api/models/__init__.py b/src/policyengine_api/models/__init__.py index 4d64c02..c49b457 100644 --- a/src/policyengine_api/models/__init__.py +++ b/src/policyengine_api/models/__init__.py @@ -11,6 +11,7 @@ from .dataset_version import DatasetVersion, DatasetVersionCreate, 
DatasetVersionRead from .decile_impact import DecileImpact, DecileImpactCreate, DecileImpactRead from .dynamic import Dynamic, DynamicCreate, DynamicRead +from .household import Household, HouseholdCreate, HouseholdRead from .household_job import ( HouseholdJob, HouseholdJobCreate, @@ -35,7 +36,13 @@ ProgramStatisticsRead, ) from .report import Report, ReportCreate, ReportRead, ReportStatus -from .simulation import Simulation, SimulationCreate, SimulationRead, SimulationStatus +from .simulation import ( + Simulation, + SimulationCreate, + SimulationRead, + SimulationStatus, + SimulationType, +) from .tax_benefit_model import ( TaxBenefitModel, TaxBenefitModelCreate, @@ -47,6 +54,12 @@ TaxBenefitModelVersionRead, ) from .user import User, UserCreate, UserRead +from .user_household_association import ( + UserHouseholdAssociation, + UserHouseholdAssociationCreate, + UserHouseholdAssociationRead, + UserHouseholdAssociationUpdate, +) from .variable import Variable, VariableCreate, VariableRead __all__ = [ @@ -72,6 +85,9 @@ "Dynamic", "DynamicCreate", "DynamicRead", + "Household", + "HouseholdCreate", + "HouseholdRead", "HouseholdJob", "HouseholdJobCreate", "HouseholdJobRead", @@ -102,6 +118,7 @@ "SimulationCreate", "SimulationRead", "SimulationStatus", + "SimulationType", "TaxBenefitModel", "TaxBenefitModelCreate", "TaxBenefitModelRead", @@ -110,6 +127,10 @@ "TaxBenefitModelVersionRead", "User", "UserCreate", + "UserHouseholdAssociation", + "UserHouseholdAssociationCreate", + "UserHouseholdAssociationRead", + "UserHouseholdAssociationUpdate", "UserRead", "Variable", "VariableCreate", diff --git a/src/policyengine_api/models/household.py b/src/policyengine_api/models/household.py new file mode 100644 index 0000000..8a96850 --- /dev/null +++ b/src/policyengine_api/models/household.py @@ -0,0 +1,54 @@ +"""Stored household definition model.""" + +from datetime import datetime, timezone +from typing import Any, Literal +from uuid import UUID, uuid4 + +from sqlalchemy import 
JSON +from sqlmodel import Column, Field, SQLModel + + +class HouseholdBase(SQLModel): + """Base household fields.""" + + tax_benefit_model_name: str + year: int + label: str | None = None + household_data: dict[str, Any] = Field(sa_column=Column(JSON, nullable=False)) + + +class Household(HouseholdBase, table=True): + """Stored household database model.""" + + __tablename__ = "households" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + +class HouseholdCreate(SQLModel): + """Schema for creating a stored household. + + Accepts the flat structure matching the frontend Household interface: + people as an array, entity groups as optional dicts. + """ + + tax_benefit_model_name: Literal["policyengine_us", "policyengine_uk"] + year: int + label: str | None = None + people: list[dict[str, Any]] + tax_unit: dict[str, Any] | None = None + family: dict[str, Any] | None = None + spm_unit: dict[str, Any] | None = None + marital_unit: dict[str, Any] | None = None + household: dict[str, Any] | None = None + benunit: dict[str, Any] | None = None + + +class HouseholdRead(HouseholdCreate): + """Schema for reading a stored household.""" + + id: UUID + created_at: datetime + updated_at: datetime diff --git a/src/policyengine_api/models/report.py b/src/policyengine_api/models/report.py index ee1b678..bc2cd40 100644 --- a/src/policyengine_api/models/report.py +++ b/src/policyengine_api/models/report.py @@ -19,6 +19,7 @@ class ReportBase(SQLModel): label: str description: str | None = None + report_type: str | None = None user_id: UUID | None = Field(default=None, foreign_key="users.id") markdown: str | None = Field(default=None, sa_column=Column(Text)) parent_report_id: UUID | None = Field(default=None, foreign_key="reports.id") diff --git a/src/policyengine_api/models/simulation.py 
b/src/policyengine_api/models/simulation.py index b23141e..985db3e 100644 --- a/src/policyengine_api/models/simulation.py +++ b/src/policyengine_api/models/simulation.py @@ -1,13 +1,16 @@ from datetime import datetime, timezone from enum import Enum -from typing import TYPE_CHECKING -from uuid import UUID, uuid4 +from typing import TYPE_CHECKING, Any +from sqlalchemy import Column +from sqlalchemy.dialects.postgresql import JSON from sqlmodel import Field, Relationship, SQLModel +from uuid import UUID, uuid4 if TYPE_CHECKING: from .dataset import Dataset from .dynamic import Dynamic + from .household import Household from .policy import Policy from .tax_benefit_model_version import TaxBenefitModelVersion @@ -21,10 +24,19 @@ class SimulationStatus(str, Enum): FAILED = "failed" +class SimulationType(str, Enum): + """Type of simulation.""" + + HOUSEHOLD = "household" + ECONOMY = "economy" + + class SimulationBase(SQLModel): """Base simulation fields.""" - dataset_id: UUID = Field(foreign_key="datasets.id") + simulation_type: SimulationType = SimulationType.ECONOMY + dataset_id: UUID | None = Field(default=None, foreign_key="datasets.id") + household_id: UUID | None = Field(default=None, foreign_key="households.id") policy_id: UUID | None = Field(default=None, foreign_key="policies.id") dynamic_id: UUID | None = Field(default=None, foreign_key="dynamics.id") tax_benefit_model_version_id: UUID = Field( @@ -45,6 +57,9 @@ class Simulation(SimulationBase, table=True): updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) started_at: datetime | None = None completed_at: datetime | None = None + household_result: dict[str, Any] | None = Field( + default=None, sa_column=Column(JSON) + ) # Relationships dataset: "Dataset" = Relationship( @@ -53,6 +68,12 @@ class Simulation(SimulationBase, table=True): "primaryjoin": "Simulation.dataset_id==Dataset.id", } ) + household: "Household" = Relationship( + sa_relationship_kwargs={ + "foreign_keys": 
"[Simulation.household_id]", + "primaryjoin": "Simulation.household_id==Household.id", + } + ) policy: "Policy" = Relationship() dynamic: "Dynamic" = Relationship() tax_benefit_model_version: "TaxBenefitModelVersion" = Relationship() @@ -78,3 +99,4 @@ class SimulationRead(SimulationBase): updated_at: datetime started_at: datetime | None completed_at: datetime | None + household_result: dict[str, Any] | None = None diff --git a/src/policyengine_api/models/user_household_association.py b/src/policyengine_api/models/user_household_association.py new file mode 100644 index 0000000..208279a --- /dev/null +++ b/src/policyengine_api/models/user_household_association.py @@ -0,0 +1,48 @@ +"""User-household association model.""" + +from datetime import datetime, timezone +from uuid import UUID, uuid4 + +from sqlmodel import Field, SQLModel + + +class UserHouseholdAssociationBase(SQLModel): + """Base association fields.""" + + user_id: UUID = Field(foreign_key="users.id", index=True) + household_id: UUID = Field(foreign_key="households.id", index=True) + country_id: str + label: str | None = None + + +class UserHouseholdAssociation(UserHouseholdAssociationBase, table=True): + """User-household association database model.""" + + __tablename__ = "user_household_associations" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + +class UserHouseholdAssociationCreate(SQLModel): + """Schema for creating a user-household association.""" + + user_id: UUID + household_id: UUID + country_id: str + label: str | None = None + + +class UserHouseholdAssociationUpdate(SQLModel): + """Schema for updating a user-household association.""" + + label: str | None = None + + +class UserHouseholdAssociationRead(UserHouseholdAssociationBase): + """Schema for reading a user-household association.""" + + id: UUID + 
created_at: datetime + updated_at: datetime diff --git a/supabase/.temp/cli-latest b/supabase/.temp/cli-latest index 8c68db7..1dd6178 100644 --- a/supabase/.temp/cli-latest +++ b/supabase/.temp/cli-latest @@ -1 +1 @@ -v2.67.1 \ No newline at end of file +v2.75.0 \ No newline at end of file diff --git a/supabase/migrations/20251229000000_add_parameter_values_indexes.sql b/supabase/migrations_archived/20251229000000_add_parameter_values_indexes.sql similarity index 100% rename from supabase/migrations/20251229000000_add_parameter_values_indexes.sql rename to supabase/migrations_archived/20251229000000_add_parameter_values_indexes.sql diff --git a/supabase/migrations/20260103000000_add_poverty_inequality.sql b/supabase/migrations_archived/20260103000000_add_poverty_inequality.sql similarity index 100% rename from supabase/migrations/20260103000000_add_poverty_inequality.sql rename to supabase/migrations_archived/20260103000000_add_poverty_inequality.sql diff --git a/supabase/migrations/20260111000000_add_aggregate_status.sql b/supabase/migrations_archived/20260111000000_add_aggregate_status.sql similarity index 100% rename from supabase/migrations/20260111000000_add_aggregate_status.sql rename to supabase/migrations_archived/20260111000000_add_aggregate_status.sql diff --git a/supabase/migrations_archived/20260203000000_create_households.sql b/supabase/migrations_archived/20260203000000_create_households.sql new file mode 100644 index 0000000..cc1907f --- /dev/null +++ b/supabase/migrations_archived/20260203000000_create_households.sql @@ -0,0 +1,14 @@ +-- Create stored households table for persisting household definitions. 
+ +CREATE TABLE IF NOT EXISTS households ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + tax_benefit_model_name TEXT NOT NULL, + year INTEGER NOT NULL, + label TEXT, + household_data JSONB NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE INDEX idx_households_model_name ON households (tax_benefit_model_name); +CREATE INDEX idx_households_year ON households (year); diff --git a/supabase/migrations_archived/20260203000001_create_user_household_associations.sql b/supabase/migrations_archived/20260203000001_create_user_household_associations.sql new file mode 100644 index 0000000..3fdcb03 --- /dev/null +++ b/supabase/migrations_archived/20260203000001_create_user_household_associations.sql @@ -0,0 +1,14 @@ +-- Create user-household associations table for linking users to saved households. + +CREATE TABLE IF NOT EXISTS user_household_associations ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + household_id UUID NOT NULL REFERENCES households(id) ON DELETE CASCADE, + country_id TEXT NOT NULL, + label TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE INDEX idx_user_household_assoc_user ON user_household_associations (user_id); +CREATE INDEX idx_user_household_assoc_household ON user_household_associations (household_id); diff --git a/supabase/migrations_archived/20260203000002_simulation_household_support.sql b/supabase/migrations_archived/20260203000002_simulation_household_support.sql new file mode 100644 index 0000000..6813f07 --- /dev/null +++ b/supabase/migrations_archived/20260203000002_simulation_household_support.sql @@ -0,0 +1,16 @@ +-- Add simulation_type as TEXT (SQLModel enum maps to text) +ALTER TABLE simulations ADD COLUMN simulation_type TEXT NOT NULL DEFAULT 'economy'; + +-- Make dataset_id nullable (was required) +ALTER TABLE simulations ALTER 
COLUMN dataset_id DROP NOT NULL; + +-- Add household support columns +ALTER TABLE simulations ADD COLUMN household_id UUID REFERENCES households(id); +ALTER TABLE simulations ADD COLUMN household_result JSONB; + +-- Indexes +CREATE INDEX idx_simulations_household ON simulations (household_id); +CREATE INDEX idx_simulations_type ON simulations (simulation_type); + +-- Add report_type to reports +ALTER TABLE reports ADD COLUMN report_type TEXT; diff --git a/test_fixtures/fixtures_household_analysis.py b/test_fixtures/fixtures_household_analysis.py new file mode 100644 index 0000000..573930a --- /dev/null +++ b/test_fixtures/fixtures_household_analysis.py @@ -0,0 +1,366 @@ +"""Fixtures and helpers for household analysis endpoint tests.""" + +from typing import Any +from unittest.mock import patch +from uuid import UUID + +import pytest +from sqlmodel import Session + +from policyengine_api.models import ( + Household, + Parameter, + ParameterValue, + Policy, + TaxBenefitModel, + TaxBenefitModelVersion, +) + + +# ============================================================================= +# Sample Calculation Results +# ============================================================================= + + +SAMPLE_UK_BASELINE_RESULT: dict[str, Any] = { + "person": [ + { + "age": 30, + "employment_income": 35000.0, + "income_tax": 4500.0, + "national_insurance": 2800.0, + "net_income": 27700.0, + } + ], + "benunit": [ + { + "universal_credit": 0.0, + "child_benefit": 0.0, + } + ], + "household": [ + { + "region": "LONDON", + "council_tax": 1500.0, + } + ], +} + + +SAMPLE_UK_REFORM_RESULT: dict[str, Any] = { + "person": [ + { + "age": 30, + "employment_income": 35000.0, + "income_tax": 4000.0, + "national_insurance": 2800.0, + "net_income": 28200.0, + } + ], + "benunit": [ + { + "universal_credit": 0.0, + "child_benefit": 0.0, + } + ], + "household": [ + { + "region": "LONDON", + "council_tax": 1500.0, + } + ], +} + + +SAMPLE_US_BASELINE_RESULT: dict[str, Any] = { + 
"person": [ + { + "age": 30, + "employment_income": 50000.0, + "income_tax": 6000.0, + "fica": 3825.0, + "net_income": 40175.0, + } + ], + "tax_unit": [ + { + "state_code": "CA", + "state_income_tax": 2500.0, + } + ], + "spm_unit": [{"snap": 0.0}], + "family": [{}], + "marital_unit": [{}], + "household": [{"state_fips": 6}], +} + + +SAMPLE_US_REFORM_RESULT: dict[str, Any] = { + "person": [ + { + "age": 30, + "employment_income": 50000.0, + "income_tax": 5500.0, + "fica": 3825.0, + "net_income": 40675.0, + } + ], + "tax_unit": [ + { + "state_code": "CA", + "state_income_tax": 2500.0, + } + ], + "spm_unit": [{"snap": 0.0}], + "family": [{}], + "marital_unit": [{}], + "household": [{"state_fips": 6}], +} + + +# ============================================================================= +# Mock Calculator Functions +# ============================================================================= + + +def mock_calculate_uk_household( + household_data: dict[str, Any], + year: int, + policy_data: dict | None, +) -> dict: + """Mock UK calculator that returns sample results.""" + if policy_data: + return SAMPLE_UK_REFORM_RESULT + return SAMPLE_UK_BASELINE_RESULT + + +def mock_calculate_us_household( + household_data: dict[str, Any], + year: int, + policy_data: dict | None, +) -> dict: + """Mock US calculator that returns sample results.""" + if policy_data: + return SAMPLE_US_REFORM_RESULT + return SAMPLE_US_BASELINE_RESULT + + +def mock_calculate_household_failing( + household_data: dict[str, Any], + year: int, + policy_data: dict | None, +) -> dict: + """Mock calculator that raises an exception.""" + raise RuntimeError("Calculation failed") + + +# ============================================================================= +# Pytest Fixtures for Mocking +# ============================================================================= + + +@pytest.fixture +def mock_uk_calculator(): + """Fixture that patches UK calculator with mock.""" + with patch( + 
"policyengine_api.api.household_analysis.calculate_uk_household", + side_effect=mock_calculate_uk_household, + ) as mock: + yield mock + + +@pytest.fixture +def mock_us_calculator(): + """Fixture that patches US calculator with mock.""" + with patch( + "policyengine_api.api.household_analysis.calculate_us_household", + side_effect=mock_calculate_us_household, + ) as mock: + yield mock + + +@pytest.fixture +def mock_calculators(): + """Fixture that patches both UK and US calculators.""" + with ( + patch( + "policyengine_api.api.household_analysis.calculate_uk_household", + side_effect=mock_calculate_uk_household, + ) as uk_mock, + patch( + "policyengine_api.api.household_analysis.calculate_us_household", + side_effect=mock_calculate_us_household, + ) as us_mock, + ): + yield {"uk": uk_mock, "us": us_mock} + + +@pytest.fixture +def mock_failing_calculator(): + """Fixture that patches calculators to fail.""" + with ( + patch( + "policyengine_api.api.household_analysis.calculate_uk_household", + side_effect=mock_calculate_household_failing, + ), + patch( + "policyengine_api.api.household_analysis.calculate_us_household", + side_effect=mock_calculate_household_failing, + ), + ): + yield + + +# ============================================================================= +# Database Factory Functions +# ============================================================================= + + +def create_tax_benefit_model( + session: Session, + name: str = "policyengine-uk", + description: str = "UK tax benefit model", +) -> TaxBenefitModel: + """Create and persist a TaxBenefitModel record.""" + model = TaxBenefitModel( + name=name, + description=description, + ) + session.add(model) + session.commit() + session.refresh(model) + return model + + +def create_model_version( + session: Session, + model_id: UUID, + version: str = "1.0.0", + description: str = "Test version", +) -> TaxBenefitModelVersion: + """Create and persist a TaxBenefitModelVersion record.""" + model_version = 
TaxBenefitModelVersion( + model_id=model_id, + version=version, + description=description, + ) + session.add(model_version) + session.commit() + session.refresh(model_version) + return model_version + + +def create_parameter( + session: Session, + model_version_id: UUID, + name: str = "test_parameter", + label: str = "Test Parameter", + description: str = "A test parameter", +) -> Parameter: + """Create and persist a Parameter record.""" + param = Parameter( + tax_benefit_model_version_id=model_version_id, + name=name, + label=label, + description=description, + ) + session.add(param) + session.commit() + session.refresh(param) + return param + + +def create_policy( + session: Session, + model_version_id: UUID, + name: str = "Test Policy", + description: str = "A test policy", +) -> Policy: + """Create and persist a Policy record.""" + policy = Policy( + tax_benefit_model_version_id=model_version_id, + name=name, + description=description, + ) + session.add(policy) + session.commit() + session.refresh(policy) + return policy + + +def create_policy_with_parameter_value( + session: Session, + model_version_id: UUID, + parameter_id: UUID, + value: float, + name: str = "Test Policy", +) -> Policy: + """Create a Policy with an associated ParameterValue.""" + policy = create_policy(session, model_version_id, name=name) + + param_value = ParameterValue( + policy_id=policy.id, + parameter_id=parameter_id, + value_json={"value": value}, + ) + session.add(param_value) + session.commit() + session.refresh(policy) + return policy + + +def create_household_for_analysis( + session: Session, + tax_benefit_model_name: str = "policyengine_uk", + year: int = 2024, + label: str = "Test household for analysis", +) -> Household: + """Create a household suitable for analysis testing.""" + if tax_benefit_model_name == "policyengine_uk": + household_data = { + "people": [{"age": 30, "employment_income": 35000}], + "benunit": {}, + "household": {"region": "LONDON"}, + } + else: + 
household_data = { + "people": [{"age": 30, "employment_income": 50000}], + "tax_unit": {"state_code": "CA"}, + "family": {}, + "spm_unit": {}, + "marital_unit": {}, + "household": {"state_fips": 6}, + } + + record = Household( + tax_benefit_model_name=tax_benefit_model_name, + year=year, + label=label, + household_data=household_data, + ) + session.add(record) + session.commit() + session.refresh(record) + return record + + +def setup_uk_model_and_version( + session: Session, +) -> tuple[TaxBenefitModel, TaxBenefitModelVersion]: + """Create UK model and version for testing.""" + model = create_tax_benefit_model( + session, name="policyengine-uk", description="UK model" + ) + version = create_model_version(session, model.id) + return model, version + + +def setup_us_model_and_version( + session: Session, +) -> tuple[TaxBenefitModel, TaxBenefitModelVersion]: + """Create US model and version for testing.""" + model = create_tax_benefit_model( + session, name="policyengine-us", description="US model" + ) + version = create_model_version(session, model.id) + return model, version diff --git a/test_fixtures/fixtures_households.py b/test_fixtures/fixtures_households.py new file mode 100644 index 0000000..4e676f4 --- /dev/null +++ b/test_fixtures/fixtures_households.py @@ -0,0 +1,66 @@ +"""Fixtures and helpers for household CRUD tests.""" + +from policyengine_api.models import Household + +# ----------------------------------------------------------------------------- +# Request payloads (match HouseholdCreate schema) +# ----------------------------------------------------------------------------- + +MOCK_US_HOUSEHOLD_CREATE = { + "tax_benefit_model_name": "policyengine_us", + "year": 2024, + "label": "US test household", + "people": [ + {"age": 30, "employment_income": 50000}, + {"age": 28, "employment_income": 30000}, + ], + "tax_unit": {}, + "family": {}, + "household": {"state_name": "CA"}, +} + +MOCK_UK_HOUSEHOLD_CREATE = { + "tax_benefit_model_name": 
"policyengine_uk", + "year": 2024, + "label": "UK test household", + "people": [ + {"age": 40, "employment_income": 35000}, + ], + "benunit": {"is_married": False}, + "household": {"region": "LONDON"}, +} + +MOCK_HOUSEHOLD_MINIMAL = { + "tax_benefit_model_name": "policyengine_us", + "year": 2024, + "people": [{"age": 25}], +} + + +# ----------------------------------------------------------------------------- +# Factory functions +# ----------------------------------------------------------------------------- + + +def create_household( + session, + tax_benefit_model_name: str = "policyengine_us", + year: int = 2024, + label: str | None = "Test household", + people: list | None = None, + **entity_groups, +) -> Household: + """Create and persist a Household record.""" + household_data = {"people": people or [{"age": 30}]} + household_data.update(entity_groups) + + record = Household( + tax_benefit_model_name=tax_benefit_model_name, + year=year, + label=label, + household_data=household_data, + ) + session.add(record) + session.commit() + session.refresh(record) + return record diff --git a/test_fixtures/fixtures_policy_reform.py b/test_fixtures/fixtures_policy_reform.py new file mode 100644 index 0000000..f7534a5 --- /dev/null +++ b/test_fixtures/fixtures_policy_reform.py @@ -0,0 +1,282 @@ +"""Fixtures for policy reform conversion tests.""" + +from dataclasses import dataclass +from datetime import date, datetime +from typing import Any + + +# ============================================================================= +# Mock objects for testing _pe_policy_to_reform_dict +# ============================================================================= + + +@dataclass +class MockParameter: + """Mock policyengine.core.models.parameter.Parameter.""" + + name: str + + +@dataclass +class MockParameterValue: + """Mock policyengine.core.models.parameter_value.ParameterValue.""" + + parameter: MockParameter | None + value: Any + start_date: date | datetime | str | None + + 
+@dataclass +class MockPolicy: + """Mock policyengine.core.policy.Policy.""" + + parameter_values: list[MockParameterValue] | None + + +# ============================================================================= +# Test data constants +# ============================================================================= + +# Simple policy with single parameter change +SIMPLE_POLICY = MockPolicy( + parameter_values=[ + MockParameterValue( + parameter=MockParameter(name="gov.irs.credits.ctc.amount.base"), + value=3000, + start_date=date(2024, 1, 1), + ) + ] +) + +SIMPLE_POLICY_EXPECTED = { + "gov.irs.credits.ctc.amount.base": {"2024-01-01": 3000} +} + +# Policy with multiple parameter changes +MULTI_PARAM_POLICY = MockPolicy( + parameter_values=[ + MockParameterValue( + parameter=MockParameter(name="gov.irs.credits.ctc.amount.base"), + value=3000, + start_date=date(2024, 1, 1), + ), + MockParameterValue( + parameter=MockParameter(name="gov.irs.credits.ctc.refundable.fully_refundable"), + value=True, + start_date=date(2024, 1, 1), + ), + MockParameterValue( + parameter=MockParameter(name="gov.irs.income.bracket.rates.1"), + value=0.12, + start_date=date(2024, 1, 1), + ), + ] +) + +MULTI_PARAM_POLICY_EXPECTED = { + "gov.irs.credits.ctc.amount.base": {"2024-01-01": 3000}, + "gov.irs.credits.ctc.refundable.fully_refundable": {"2024-01-01": True}, + "gov.irs.income.bracket.rates.1": {"2024-01-01": 0.12}, +} + +# Policy with same parameter at different dates +MULTI_DATE_POLICY = MockPolicy( + parameter_values=[ + MockParameterValue( + parameter=MockParameter(name="gov.irs.credits.ctc.amount.base"), + value=2500, + start_date=date(2024, 1, 1), + ), + MockParameterValue( + parameter=MockParameter(name="gov.irs.credits.ctc.amount.base"), + value=3000, + start_date=date(2025, 1, 1), + ), + ] +) + +MULTI_DATE_POLICY_EXPECTED = { + "gov.irs.credits.ctc.amount.base": { + "2024-01-01": 2500, + "2025-01-01": 3000, + } +} + +# Policy with datetime start_date (has time component) 
+DATETIME_POLICY = MockPolicy( + parameter_values=[ + MockParameterValue( + parameter=MockParameter(name="gov.irs.credits.ctc.amount.base"), + value=3000, + start_date=datetime(2024, 1, 1, 12, 30, 45), + ) + ] +) + +DATETIME_POLICY_EXPECTED = { + "gov.irs.credits.ctc.amount.base": {"2024-01-01": 3000} +} + +# Policy with ISO string start_date +ISO_STRING_POLICY = MockPolicy( + parameter_values=[ + MockParameterValue( + parameter=MockParameter(name="gov.irs.credits.ctc.amount.base"), + value=3000, + start_date="2024-01-01T00:00:00", + ) + ] +) + +ISO_STRING_POLICY_EXPECTED = { + "gov.irs.credits.ctc.amount.base": {"2024-01-01": 3000} +} + +# Empty policy (no parameter values) +EMPTY_POLICY = MockPolicy(parameter_values=[]) + +# None policy +NONE_POLICY = None + +# Policy with None parameter_values +NONE_PARAM_VALUES_POLICY = MockPolicy(parameter_values=None) + +# Policy with invalid entries (missing parameter or start_date) +INVALID_ENTRIES_POLICY = MockPolicy( + parameter_values=[ + MockParameterValue( + parameter=None, # Missing parameter + value=3000, + start_date=date(2024, 1, 1), + ), + MockParameterValue( + parameter=MockParameter(name="gov.irs.credits.ctc.amount.base"), + value=3000, + start_date=None, # Missing start_date + ), + MockParameterValue( + parameter=MockParameter(name="gov.irs.credits.eitc.max.0"), + value=600, + start_date=date(2024, 1, 1), # This one is valid + ), + ] +) + +INVALID_ENTRIES_POLICY_EXPECTED = { + "gov.irs.credits.eitc.max.0": {"2024-01-01": 600} +} + + +# ============================================================================= +# Test data for _merge_reform_dicts +# ============================================================================= + +REFORM_DICT_1 = { + "gov.irs.credits.ctc.amount.base": {"2024-01-01": 2000}, + "gov.irs.income.bracket.rates.1": {"2024-01-01": 0.10}, +} + +REFORM_DICT_2 = { + "gov.irs.credits.ctc.amount.base": {"2024-01-01": 3000}, # Overwrites + "gov.irs.credits.eitc.max.0": {"2024-01-01": 600}, # 
New param +} + +MERGED_EXPECTED = { + "gov.irs.credits.ctc.amount.base": {"2024-01-01": 3000}, # From reform2 + "gov.irs.income.bracket.rates.1": {"2024-01-01": 0.10}, # From reform1 + "gov.irs.credits.eitc.max.0": {"2024-01-01": 600}, # From reform2 +} + +REFORM_DICT_3 = { + "gov.irs.credits.ctc.amount.base": { + "2024-01-01": 2500, + "2025-01-01": 2700, + }, +} + +REFORM_DICT_4 = { + "gov.irs.credits.ctc.amount.base": { + "2025-01-01": 3000, # Overwrites 2025 date + "2026-01-01": 3500, # New date + }, +} + +MERGED_MULTI_DATE_EXPECTED = { + "gov.irs.credits.ctc.amount.base": { + "2024-01-01": 2500, # From reform3 + "2025-01-01": 3000, # From reform4 (overwrites) + "2026-01-01": 3500, # From reform4 (new) + }, +} + + +# ============================================================================= +# Test data for household calculation policy conversion +# ============================================================================= + +# Policy data as it comes from the API (stored in database) +HOUSEHOLD_POLICY_DATA = { + "parameter_values": [ + { + "parameter_name": "gov.irs.credits.ctc.amount.base", + "value": 3000, + "start_date": "2024-01-01", + }, + { + "parameter_name": "gov.irs.credits.ctc.refundable.fully_refundable", + "value": True, + "start_date": "2024-01-01", + }, + ] +} + +HOUSEHOLD_POLICY_DATA_EXPECTED = { + "gov.irs.credits.ctc.amount.base": {"2024-01-01": 3000}, + "gov.irs.credits.ctc.refundable.fully_refundable": {"2024-01-01": True}, +} + +# Policy data with ISO datetime strings +HOUSEHOLD_POLICY_DATA_DATETIME = { + "parameter_values": [ + { + "parameter_name": "gov.irs.credits.ctc.amount.base", + "value": 3000, + "start_date": "2024-01-01T00:00:00.000Z", + }, + ] +} + +HOUSEHOLD_POLICY_DATA_DATETIME_EXPECTED = { + "gov.irs.credits.ctc.amount.base": {"2024-01-01": 3000}, +} + +# Empty policy data +HOUSEHOLD_EMPTY_POLICY_DATA = {"parameter_values": []} + +# None policy data +HOUSEHOLD_NONE_POLICY_DATA = None + +# Policy data with missing fields 
+HOUSEHOLD_INCOMPLETE_POLICY_DATA = { + "parameter_values": [ + { + "parameter_name": None, # Missing + "value": 3000, + "start_date": "2024-01-01", + }, + { + "parameter_name": "gov.irs.credits.ctc.amount.base", + "value": 3000, + "start_date": None, # Missing + }, + { + "parameter_name": "gov.irs.credits.eitc.max.0", + "value": 600, + "start_date": "2024-01-01", # Valid + }, + ] +} + +HOUSEHOLD_INCOMPLETE_POLICY_DATA_EXPECTED = { + "gov.irs.credits.eitc.max.0": {"2024-01-01": 600}, +} diff --git a/test_fixtures/fixtures_user_household_associations.py b/test_fixtures/fixtures_user_household_associations.py new file mode 100644 index 0000000..66b0835 --- /dev/null +++ b/test_fixtures/fixtures_user_household_associations.py @@ -0,0 +1,62 @@ +"""Fixtures and helpers for user-household association tests.""" + +from uuid import UUID + +from policyengine_api.models import Household, User, UserHouseholdAssociation + +# ----------------------------------------------------------------------------- +# Factory functions +# ----------------------------------------------------------------------------- + + +def create_user( + session, + first_name: str = "Test", + last_name: str = "User", + email: str = "test@example.com", +) -> User: + """Create and persist a User record.""" + record = User(first_name=first_name, last_name=last_name, email=email) + session.add(record) + session.commit() + session.refresh(record) + return record + + +def create_household( + session, + tax_benefit_model_name: str = "policyengine_us", + year: int = 2024, + label: str | None = "Test household", +) -> Household: + """Create and persist a Household record.""" + record = Household( + tax_benefit_model_name=tax_benefit_model_name, + year=year, + label=label, + household_data={"people": [{"age": 30}]}, + ) + session.add(record) + session.commit() + session.refresh(record) + return record + + +def create_association( + session, + user_id: UUID, + household_id: UUID, + country_id: str = "us", + label: 
str | None = "My household", +) -> UserHouseholdAssociation: + """Create and persist a UserHouseholdAssociation record.""" + record = UserHouseholdAssociation( + user_id=user_id, + household_id=household_id, + country_id=country_id, + label=label, + ) + session.add(record) + session.commit() + session.refresh(record) + return record diff --git a/tests/test_analysis_household_impact.py b/tests/test_analysis_household_impact.py new file mode 100644 index 0000000..23465c7 --- /dev/null +++ b/tests/test_analysis_household_impact.py @@ -0,0 +1,526 @@ +"""Tests for household impact analysis endpoints.""" + +from datetime import date +from uuid import UUID, uuid4 + +import pytest + +from test_fixtures.fixtures_household_analysis import ( + SAMPLE_UK_BASELINE_RESULT, + SAMPLE_UK_REFORM_RESULT, + SAMPLE_US_BASELINE_RESULT, + SAMPLE_US_REFORM_RESULT, + create_household_for_analysis, + create_policy, + setup_uk_model_and_version, + setup_us_model_and_version, +) +from policyengine_api.api.household_analysis import ( + UK_CONFIG, + US_CONFIG, + _ensure_list, + _extract_value, + _format_date, + compute_entity_diff, + compute_entity_list_diff, + compute_household_impact, + compute_variable_diff, + get_calculator, + get_country_config, +) +from policyengine_api.models import Report, ReportStatus, Simulation, SimulationType + + +# --------------------------------------------------------------------------- +# Unit tests for helper functions +# --------------------------------------------------------------------------- + + +class TestEnsureList: + """Tests for _ensure_list helper.""" + + def test_none_returns_empty_list(self): + assert _ensure_list(None) == [] + + def test_list_returns_same_list(self): + input_list = [1, 2, 3] + assert _ensure_list(input_list) == input_list + + def test_dict_wrapped_in_list(self): + input_dict = {"key": "value"} + result = _ensure_list(input_dict) + assert result == [input_dict] + + def test_empty_list_returns_empty_list(self): + assert 
_ensure_list([]) == [] + + +class TestExtractValue: + """Tests for _extract_value helper.""" + + def test_dict_with_value_key(self): + assert _extract_value({"value": 100}) == 100 + + def test_dict_without_value_key(self): + assert _extract_value({"other": 100}) is None + + def test_non_dict_returns_as_is(self): + assert _extract_value(100) == 100 + assert _extract_value("string") == "string" + assert _extract_value([1, 2]) == [1, 2] + + +class TestFormatDate: + """Tests for _format_date helper.""" + + def test_none_returns_none(self): + assert _format_date(None) is None + + def test_date_object_formatted(self): + d = date(2024, 1, 15) + assert _format_date(d) == "2024-01-15" + + def test_string_returns_string(self): + assert _format_date("2024-01-15") == "2024-01-15" + + +class TestComputeVariableDiff: + """Tests for compute_variable_diff helper.""" + + def test_numeric_values_return_diff(self): + result = compute_variable_diff(100, 150) + assert result == {"baseline": 100, "reform": 150, "change": 50} + + def test_negative_change(self): + result = compute_variable_diff(150, 100) + assert result == {"baseline": 150, "reform": 100, "change": -50} + + def test_float_values(self): + result = compute_variable_diff(100.5, 200.5) + assert result == {"baseline": 100.5, "reform": 200.5, "change": 100.0} + + def test_non_numeric_baseline_returns_none(self): + assert compute_variable_diff("string", 100) is None + + def test_non_numeric_reform_returns_none(self): + assert compute_variable_diff(100, "string") is None + + def test_both_non_numeric_returns_none(self): + assert compute_variable_diff("a", "b") is None + + +class TestComputeEntityDiff: + """Tests for compute_entity_diff helper.""" + + def test_computes_diff_for_numeric_keys(self): + baseline = {"income": 1000, "tax": 200, "name": "John"} + reform = {"income": 1000, "tax": 150, "name": "John"} + result = compute_entity_diff(baseline, reform) + + assert "income" in result + assert result["income"]["change"] == 0 + 
assert "tax" in result + assert result["tax"]["change"] == -50 + assert "name" not in result + + def test_missing_key_in_reform_skipped(self): + baseline = {"income": 1000, "tax": 200} + reform = {"income": 1000} + result = compute_entity_diff(baseline, reform) + + assert "income" in result + assert "tax" not in result + + def test_empty_entities(self): + assert compute_entity_diff({}, {}) == {} + + +class TestComputeEntityListDiff: + """Tests for compute_entity_list_diff helper.""" + + def test_computes_diff_for_each_pair(self): + baseline_list = [{"income": 100}, {"income": 200}] + reform_list = [{"income": 120}, {"income": 180}] + result = compute_entity_list_diff(baseline_list, reform_list) + + assert len(result) == 2 + assert result[0]["income"]["change"] == 20 + assert result[1]["income"]["change"] == -20 + + def test_empty_lists(self): + assert compute_entity_list_diff([], []) == [] + + +class TestComputeHouseholdImpact: + """Tests for compute_household_impact helper.""" + + def test_uk_household_impact(self): + result = compute_household_impact( + SAMPLE_UK_BASELINE_RESULT, + SAMPLE_UK_REFORM_RESULT, + UK_CONFIG, + ) + + assert "person" in result + assert "benunit" in result + assert "household" in result + + # Check person income_tax changed + person_diff = result["person"][0] + assert "income_tax" in person_diff + assert person_diff["income_tax"]["baseline"] == 4500.0 + assert person_diff["income_tax"]["reform"] == 4000.0 + assert person_diff["income_tax"]["change"] == -500.0 + + def test_us_household_impact(self): + result = compute_household_impact( + SAMPLE_US_BASELINE_RESULT, + SAMPLE_US_REFORM_RESULT, + US_CONFIG, + ) + + assert "person" in result + assert "tax_unit" in result + assert "spm_unit" in result + assert "family" in result + assert "marital_unit" in result + assert "household" in result + + # Check person income_tax changed + person_diff = result["person"][0] + assert person_diff["income_tax"]["change"] == -500.0 + + def 
test_missing_entity_skipped(self): + baseline = {"person": [{"income": 100}]} + reform = {"person": [{"income": 120}]} + result = compute_household_impact(baseline, reform, UK_CONFIG) + + assert "person" in result + assert "benunit" not in result + assert "household" not in result + + +class TestGetCountryConfig: + """Tests for get_country_config helper.""" + + def test_uk_model_returns_uk_config(self): + config = get_country_config("policyengine_uk") + assert config == UK_CONFIG + assert config.name == "uk" + assert "benunit" in config.entity_types + + def test_us_model_returns_us_config(self): + config = get_country_config("policyengine_us") + assert config == US_CONFIG + assert config.name == "us" + assert "tax_unit" in config.entity_types + + def test_unknown_model_defaults_to_us(self): + config = get_country_config("unknown_model") + assert config == US_CONFIG + + +class TestGetCalculator: + """Tests for get_calculator helper.""" + + def test_uk_model_returns_uk_calculator(self): + from policyengine_api.api.household_analysis import calculate_uk_household + + calc = get_calculator("policyengine_uk") + assert calc == calculate_uk_household + + def test_us_model_returns_us_calculator(self): + from policyengine_api.api.household_analysis import calculate_us_household + + calc = get_calculator("policyengine_us") + assert calc == calculate_us_household + + def test_unknown_model_defaults_to_us(self): + from policyengine_api.api.household_analysis import calculate_us_household + + calc = get_calculator("unknown_model") + assert calc == calculate_us_household + + +# --------------------------------------------------------------------------- +# Validation tests (no database required beyond session fixture) +# --------------------------------------------------------------------------- + + +class TestHouseholdImpactValidation: + """Tests for request validation.""" + + def test_missing_household_id(self, client): + """Test that missing household_id returns 422.""" + 
response = client.post( + "/analysis/household-impact", + json={}, + ) + assert response.status_code == 422 + + def test_invalid_uuid(self, client): + """Test that invalid UUID returns 422.""" + response = client.post( + "/analysis/household-impact", + json={ + "household_id": "not-a-uuid", + }, + ) + assert response.status_code == 422 + + +# --------------------------------------------------------------------------- +# 404 tests +# --------------------------------------------------------------------------- + + +class TestHouseholdImpactNotFound: + """Tests for 404 responses.""" + + def test_household_not_found(self, client, session): + """Test that non-existent household returns 404.""" + # Need model for the model version lookup + setup_uk_model_and_version(session) + + response = client.post( + "/analysis/household-impact", + json={ + "household_id": str(uuid4()), + }, + ) + assert response.status_code == 404 + assert "not found" in response.json()["detail"].lower() + + def test_policy_not_found(self, client, session): + """Test that non-existent policy returns 404.""" + setup_uk_model_and_version(session) + household = create_household_for_analysis(session) + + response = client.post( + "/analysis/household-impact", + json={ + "household_id": str(household.id), + "policy_id": str(uuid4()), + }, + ) + assert response.status_code == 404 + assert "not found" in response.json()["detail"].lower() + + def test_get_report_not_found(self, client): + """Test that GET with non-existent report_id returns 404.""" + response = client.get(f"/analysis/household-impact/{uuid4()}") + assert response.status_code == 404 + + +# --------------------------------------------------------------------------- +# Record creation tests +# --------------------------------------------------------------------------- + + +class TestHouseholdImpactRecordCreation: + """Tests for correct record creation.""" + + def test_single_run_creates_one_simulation(self, client, session): + """Single run (no 
policy_id) creates one simulation.""" + _, version = setup_uk_model_and_version(session) + household = create_household_for_analysis(session) + + response = client.post( + "/analysis/household-impact", + json={ + "household_id": str(household.id), + }, + ) + # May fail during calculation since policyengine not available, + # but should create records + data = response.json() + assert "report_id" in data + assert data["report_type"] == "household_single" + assert data["baseline_simulation"] is not None + assert data["reform_simulation"] is None + + def test_comparison_creates_two_simulations(self, client, session): + """Comparison (with policy_id) creates two simulations.""" + _, version = setup_uk_model_and_version(session) + household = create_household_for_analysis(session) + policy = create_policy(session, version.id) + + response = client.post( + "/analysis/household-impact", + json={ + "household_id": str(household.id), + "policy_id": str(policy.id), + }, + ) + data = response.json() + assert "report_id" in data + assert data["report_type"] == "household_comparison" + assert data["baseline_simulation"] is not None + assert data["reform_simulation"] is not None + + def test_simulation_type_is_household(self, client, session): + """Created simulations have simulation_type=HOUSEHOLD.""" + _, version = setup_uk_model_and_version(session) + household = create_household_for_analysis(session) + + response = client.post( + "/analysis/household-impact", + json={ + "household_id": str(household.id), + }, + ) + data = response.json() + + # Check simulation in database (convert string to UUID for query) + sim_id = UUID(data["baseline_simulation"]["id"]) + sim = session.get(Simulation, sim_id) + assert sim is not None + assert sim.simulation_type == SimulationType.HOUSEHOLD + assert sim.household_id == household.id + assert sim.dataset_id is None + + def test_report_links_simulations(self, client, session): + """Report correctly links baseline and reform simulations.""" + 
_, version = setup_uk_model_and_version(session) + household = create_household_for_analysis(session) + policy = create_policy(session, version.id) + + response = client.post( + "/analysis/household-impact", + json={ + "household_id": str(household.id), + "policy_id": str(policy.id), + }, + ) + data = response.json() + + # Check report in database (convert string to UUID for query) + report = session.get(Report, UUID(data["report_id"])) + assert report is not None + assert report.baseline_simulation_id == UUID(data["baseline_simulation"]["id"]) + assert report.reform_simulation_id == UUID(data["reform_simulation"]["id"]) + assert report.report_type == "household_comparison" + + +# --------------------------------------------------------------------------- +# Deduplication tests +# --------------------------------------------------------------------------- + + +class TestHouseholdImpactDeduplication: + """Tests for simulation/report deduplication.""" + + def test_same_request_returns_same_simulation(self, client, session): + """Same household + same parameters returns same simulation ID.""" + _, version = setup_uk_model_and_version(session) + household = create_household_for_analysis(session) + + # First request + response1 = client.post( + "/analysis/household-impact", + json={"household_id": str(household.id)}, + ) + data1 = response1.json() + + # Second request with same parameters + response2 = client.post( + "/analysis/household-impact", + json={"household_id": str(household.id)}, + ) + data2 = response2.json() + + # Should return same IDs + assert data1["report_id"] == data2["report_id"] + assert data1["baseline_simulation"]["id"] == data2["baseline_simulation"]["id"] + + def test_different_policy_creates_different_simulation(self, client, session): + """Different policy creates different simulation.""" + _, version = setup_uk_model_and_version(session) + household = create_household_for_analysis(session) + policy1 = create_policy(session, version.id, 
name="Policy 1") + policy2 = create_policy(session, version.id, name="Policy 2") + + # Request with policy1 + response1 = client.post( + "/analysis/household-impact", + json={ + "household_id": str(household.id), + "policy_id": str(policy1.id), + }, + ) + data1 = response1.json() + + # Request with policy2 + response2 = client.post( + "/analysis/household-impact", + json={ + "household_id": str(household.id), + "policy_id": str(policy2.id), + }, + ) + data2 = response2.json() + + # Reports should be different + assert data1["report_id"] != data2["report_id"] + # Reform simulations should be different + assert ( + data1["reform_simulation"]["id"] != data2["reform_simulation"]["id"] + ) + # Baseline simulations should be the same (same household, no policy) + assert ( + data1["baseline_simulation"]["id"] == data2["baseline_simulation"]["id"] + ) + + +# --------------------------------------------------------------------------- +# GET endpoint tests +# --------------------------------------------------------------------------- + + +class TestGetHouseholdImpact: + """Tests for GET /analysis/household-impact/{report_id}.""" + + def test_get_returns_report_data(self, client, session): + """GET returns report with simulation info.""" + _, version = setup_uk_model_and_version(session) + household = create_household_for_analysis(session) + + # Create report via POST + post_response = client.post( + "/analysis/household-impact", + json={"household_id": str(household.id)}, + ) + report_id = post_response.json()["report_id"] + + # GET the report + get_response = client.get(f"/analysis/household-impact/{report_id}") + assert get_response.status_code == 200 + + data = get_response.json() + assert data["report_id"] == report_id + assert data["report_type"] == "household_single" + assert data["baseline_simulation"] is not None + + +# --------------------------------------------------------------------------- +# US household tests +# 
--------------------------------------------------------------------------- + + +class TestUSHouseholdImpact: + """Tests specific to US households.""" + + def test_us_household_creates_simulation(self, client, session): + """US household creates simulation with correct model.""" + _, version = setup_us_model_and_version(session) + household = create_household_for_analysis( + session, tax_benefit_model_name="policyengine_us" + ) + + response = client.post( + "/analysis/household-impact", + json={"household_id": str(household.id)}, + ) + data = response.json() + assert "report_id" in data + assert data["baseline_simulation"] is not None diff --git a/tests/test_households.py b/tests/test_households.py new file mode 100644 index 0000000..4c60062 --- /dev/null +++ b/tests/test_households.py @@ -0,0 +1,155 @@ +"""Tests for stored household CRUD endpoints.""" + +from uuid import uuid4 + +from test_fixtures.fixtures_households import ( + MOCK_HOUSEHOLD_MINIMAL, + MOCK_UK_HOUSEHOLD_CREATE, + MOCK_US_HOUSEHOLD_CREATE, + create_household, +) + +# --------------------------------------------------------------------------- +# POST /households +# --------------------------------------------------------------------------- + + +def test_create_us_household(client): + """Create a US household returns 201 with id and timestamps.""" + response = client.post("/households", json=MOCK_US_HOUSEHOLD_CREATE) + assert response.status_code == 201 + data = response.json() + assert "id" in data + assert "created_at" in data + assert "updated_at" in data + assert data["tax_benefit_model_name"] == "policyengine_us" + assert data["year"] == 2024 + assert data["label"] == "US test household" + + +def test_create_household_returns_people_and_entities(client): + """Created household response includes people and entity groups.""" + response = client.post("/households", json=MOCK_US_HOUSEHOLD_CREATE) + data = response.json() + assert len(data["people"]) == 2 + assert data["people"][0]["age"] == 30 + 
assert data["people"][0]["employment_income"] == 50000 + assert data["household"] == {"state_name": "CA"} + assert data["tax_unit"] == {} + assert data["family"] == {} + + +def test_create_uk_household(client): + """Create a UK household with benunit.""" + response = client.post("/households", json=MOCK_UK_HOUSEHOLD_CREATE) + assert response.status_code == 201 + data = response.json() + assert data["tax_benefit_model_name"] == "policyengine_uk" + assert data["benunit"] == {"is_married": False} + assert data["household"] == {"region": "LONDON"} + + +def test_create_household_minimal(client): + """Create a household with minimal fields.""" + response = client.post("/households", json=MOCK_HOUSEHOLD_MINIMAL) + assert response.status_code == 201 + data = response.json() + assert data["label"] is None + assert data["tax_unit"] is None + assert data["benunit"] is None + + +def test_create_household_invalid_model_name(client): + """Reject invalid tax_benefit_model_name.""" + payload = {**MOCK_HOUSEHOLD_MINIMAL, "tax_benefit_model_name": "invalid"} + response = client.post("/households", json=payload) + assert response.status_code == 422 + + +# --------------------------------------------------------------------------- +# GET /households/{id} +# --------------------------------------------------------------------------- + + +def test_get_household(client, session): + """Get a stored household by ID.""" + record = create_household(session) + response = client.get(f"/households/{record.id}") + assert response.status_code == 200 + data = response.json() + assert data["id"] == str(record.id) + assert data["tax_benefit_model_name"] == "policyengine_us" + + +def test_get_household_not_found(client): + """Get a non-existent household returns 404.""" + fake_id = uuid4() + response = client.get(f"/households/{fake_id}") + assert response.status_code == 404 + assert "not found" in response.json()["detail"] + + +# 
--------------------------------------------------------------------------- +# GET /households +# --------------------------------------------------------------------------- + + +def test_list_households_empty(client): + """List households returns empty list when none exist.""" + response = client.get("/households") + assert response.status_code == 200 + assert response.json() == [] + + +def test_list_households_with_data(client, session): + """List households returns all stored households.""" + create_household(session, label="first") + create_household(session, label="second") + response = client.get("/households") + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + + +def test_list_households_filter_by_model_name(client, session): + """Filter households by tax_benefit_model_name.""" + create_household(session, tax_benefit_model_name="policyengine_us") + create_household(session, tax_benefit_model_name="policyengine_uk") + response = client.get( + "/households", params={"tax_benefit_model_name": "policyengine_uk"} + ) + data = response.json() + assert len(data) == 1 + assert data[0]["tax_benefit_model_name"] == "policyengine_uk" + + +def test_list_households_limit_and_offset(client, session): + """Respect limit and offset pagination.""" + for i in range(5): + create_household(session, label=f"household-{i}") + response = client.get("/households", params={"limit": 2, "offset": 1}) + data = response.json() + assert len(data) == 2 + + +# --------------------------------------------------------------------------- +# DELETE /households/{id} +# --------------------------------------------------------------------------- + + +def test_delete_household(client, session): + """Delete a household returns 204.""" + record = create_household(session) + response = client.delete(f"/households/{record.id}") + assert response.status_code == 204 + + # Confirm it's gone + response = client.get(f"/households/{record.id}") + assert 
response.status_code == 404 + + +def test_delete_household_not_found(client): + """Delete a non-existent household returns 404.""" + fake_id = uuid4() + response = client.delete(f"/households/{fake_id}") + assert response.status_code == 404 diff --git a/tests/test_policy_reform.py b/tests/test_policy_reform.py new file mode 100644 index 0000000..cfee3b8 --- /dev/null +++ b/tests/test_policy_reform.py @@ -0,0 +1,327 @@ +"""Tests for policy reform conversion logic. + +Tests the helper functions that convert policy objects to reform dict format +for use with Microsimulation. These are critical for fixing the bug where +reforms weren't being applied to economy-wide and household simulations. +""" + +import sys +from unittest.mock import MagicMock + +import pytest + +# Mock modal before importing modal_app +sys.modules["modal"] = MagicMock() + +from test_fixtures.fixtures_policy_reform import ( + DATETIME_POLICY, + DATETIME_POLICY_EXPECTED, + EMPTY_POLICY, + HOUSEHOLD_EMPTY_POLICY_DATA, + HOUSEHOLD_INCOMPLETE_POLICY_DATA, + HOUSEHOLD_INCOMPLETE_POLICY_DATA_EXPECTED, + HOUSEHOLD_NONE_POLICY_DATA, + HOUSEHOLD_POLICY_DATA, + HOUSEHOLD_POLICY_DATA_DATETIME, + HOUSEHOLD_POLICY_DATA_DATETIME_EXPECTED, + HOUSEHOLD_POLICY_DATA_EXPECTED, + INVALID_ENTRIES_POLICY, + INVALID_ENTRIES_POLICY_EXPECTED, + ISO_STRING_POLICY, + ISO_STRING_POLICY_EXPECTED, + MERGED_EXPECTED, + MERGED_MULTI_DATE_EXPECTED, + MULTI_DATE_POLICY, + MULTI_DATE_POLICY_EXPECTED, + MULTI_PARAM_POLICY, + MULTI_PARAM_POLICY_EXPECTED, + NONE_PARAM_VALUES_POLICY, + NONE_POLICY, + REFORM_DICT_1, + REFORM_DICT_2, + REFORM_DICT_3, + REFORM_DICT_4, + SIMPLE_POLICY, + SIMPLE_POLICY_EXPECTED, +) + +# Import after mocking modal +from policyengine_api.modal_app import _merge_reform_dicts, _pe_policy_to_reform_dict + + +class TestPePolicyToReformDict: + """Tests for _pe_policy_to_reform_dict function.""" + + # ========================================================================= + # Given: Valid policy with single 
parameter + # ========================================================================= + + def test__given_simple_policy_with_date_object__then_returns_correct_reform_dict( + self, + ): + """Given a policy with a single parameter using date object, + then returns correctly formatted reform dict.""" + # When + result = _pe_policy_to_reform_dict(SIMPLE_POLICY) + + # Then + assert result == SIMPLE_POLICY_EXPECTED + + def test__given_policy_with_datetime_object__then_extracts_date_correctly(self): + """Given a policy with datetime start_date (has time component), + then extracts just the date part for the reform dict.""" + # When + result = _pe_policy_to_reform_dict(DATETIME_POLICY) + + # Then + assert result == DATETIME_POLICY_EXPECTED + + def test__given_policy_with_iso_string_date__then_parses_date_correctly(self): + """Given a policy with ISO string start_date, + then parses and extracts the date correctly.""" + # When + result = _pe_policy_to_reform_dict(ISO_STRING_POLICY) + + # Then + assert result == ISO_STRING_POLICY_EXPECTED + + # ========================================================================= + # Given: Policy with multiple parameters + # ========================================================================= + + def test__given_policy_with_multiple_parameters__then_includes_all_in_dict(self): + """Given a policy with multiple parameter changes, + then includes all parameters in the reform dict.""" + # When + result = _pe_policy_to_reform_dict(MULTI_PARAM_POLICY) + + # Then + assert result == MULTI_PARAM_POLICY_EXPECTED + + def test__given_policy_with_same_param_multiple_dates__then_includes_all_dates( + self, + ): + """Given a policy with the same parameter changed at different dates, + then includes all date entries for that parameter.""" + # When + result = _pe_policy_to_reform_dict(MULTI_DATE_POLICY) + + # Then + assert result == MULTI_DATE_POLICY_EXPECTED + + # ========================================================================= + # 
Given: Empty or None policy + # ========================================================================= + + def test__given_none_policy__then_returns_none(self): + """Given None as policy, + then returns None.""" + # When + result = _pe_policy_to_reform_dict(NONE_POLICY) + + # Then + assert result is None + + def test__given_policy_with_empty_parameter_values__then_returns_none(self): + """Given a policy with empty parameter_values list, + then returns None.""" + # When + result = _pe_policy_to_reform_dict(EMPTY_POLICY) + + # Then + assert result is None + + def test__given_policy_with_none_parameter_values__then_returns_none(self): + """Given a policy with parameter_values=None, + then returns None.""" + # When + result = _pe_policy_to_reform_dict(NONE_PARAM_VALUES_POLICY) + + # Then + assert result is None + + # ========================================================================= + # Given: Policy with invalid entries + # ========================================================================= + + def test__given_policy_with_invalid_entries__then_skips_invalid_keeps_valid(self): + """Given a policy with some invalid entries (missing parameter or date), + then skips invalid entries and keeps valid ones.""" + # When + result = _pe_policy_to_reform_dict(INVALID_ENTRIES_POLICY) + + # Then + assert result == INVALID_ENTRIES_POLICY_EXPECTED + + +class TestMergeReformDicts: + """Tests for _merge_reform_dicts function.""" + + # ========================================================================= + # Given: Two valid reform dicts + # ========================================================================= + + def test__given_two_reform_dicts__then_merges_with_second_taking_precedence(self): + """Given two reform dicts with overlapping parameters, + then merges them with the second dict taking precedence.""" + # When + result = _merge_reform_dicts(REFORM_DICT_1, REFORM_DICT_2) + + # Then + assert result == MERGED_EXPECTED + + def 
test__given_dicts_with_multiple_dates__then_merges_date_entries_correctly(self): + """Given reform dicts with same parameter at multiple dates, + then merges date entries correctly with second taking precedence.""" + # When + result = _merge_reform_dicts(REFORM_DICT_3, REFORM_DICT_4) + + # Then + assert result == MERGED_MULTI_DATE_EXPECTED + + # ========================================================================= + # Given: None values + # ========================================================================= + + def test__given_both_none__then_returns_none(self): + """Given both reform dicts are None, + then returns None.""" + # When + result = _merge_reform_dicts(None, None) + + # Then + assert result is None + + def test__given_first_none__then_returns_second(self): + """Given first reform dict is None, + then returns the second dict.""" + # When + result = _merge_reform_dicts(None, REFORM_DICT_1) + + # Then + assert result == REFORM_DICT_1 + + def test__given_second_none__then_returns_first(self): + """Given second reform dict is None, + then returns the first dict.""" + # When + result = _merge_reform_dicts(REFORM_DICT_1, None) + + # Then + assert result == REFORM_DICT_1 + + # ========================================================================= + # Given: Original dict should not be mutated + # ========================================================================= + + def test__given_two_dicts__then_does_not_mutate_original_dicts(self): + """Given two reform dicts, + then merging does not mutate the original dicts.""" + # Given + original_dict1 = {"param.a": {"2024-01-01": 100}} + original_dict2 = {"param.b": {"2024-01-01": 200}} + dict1_copy = dict(original_dict1) + dict2_copy = dict(original_dict2) + + # When + _merge_reform_dicts(original_dict1, original_dict2) + + # Then + assert original_dict1 == dict1_copy + assert original_dict2 == dict2_copy + + +class TestHouseholdPolicyDataConversion: + """Tests for the policy data conversion logic 
used in household calculations. + + This tests the conversion logic as it appears in _calculate_household_us + and _calculate_household_uk functions. + """ + + def _convert_policy_data_to_reform(self, policy_data: dict | None) -> dict | None: + """Convert policy_data (from API) to reform dict format. + + This mirrors the conversion logic in _calculate_household_us. + """ + if not policy_data or not policy_data.get("parameter_values"): + return None + + reform = {} + for pv in policy_data["parameter_values"]: + param_name = pv.get("parameter_name") + value = pv.get("value") + start_date = pv.get("start_date") + + if param_name and start_date: + # Parse ISO date string to get just the date part + if "T" in start_date: + date_str = start_date.split("T")[0] + else: + date_str = start_date + + if param_name not in reform: + reform[param_name] = {} + reform[param_name][date_str] = value + + return reform if reform else None + + # ========================================================================= + # Given: Valid policy data from API + # ========================================================================= + + def test__given_valid_policy_data__then_converts_to_reform_dict(self): + """Given valid policy data from the API, + then converts it to the correct reform dict format.""" + # When + result = self._convert_policy_data_to_reform(HOUSEHOLD_POLICY_DATA) + + # Then + assert result == HOUSEHOLD_POLICY_DATA_EXPECTED + + def test__given_policy_data_with_datetime_strings__then_extracts_date_part(self): + """Given policy data with ISO datetime strings (with T and timezone), + then extracts just the date part.""" + # When + result = self._convert_policy_data_to_reform(HOUSEHOLD_POLICY_DATA_DATETIME) + + # Then + assert result == HOUSEHOLD_POLICY_DATA_DATETIME_EXPECTED + + # ========================================================================= + # Given: Empty or None policy data + # ========================================================================= + + 
def test__given_none_policy_data__then_returns_none(self): + """Given None policy data, + then returns None.""" + # When + result = self._convert_policy_data_to_reform(HOUSEHOLD_NONE_POLICY_DATA) + + # Then + assert result is None + + def test__given_empty_parameter_values__then_returns_none(self): + """Given policy data with empty parameter_values list, + then returns None.""" + # When + result = self._convert_policy_data_to_reform(HOUSEHOLD_EMPTY_POLICY_DATA) + + # Then + assert result is None + + # ========================================================================= + # Given: Incomplete policy data + # ========================================================================= + + def test__given_incomplete_entries__then_skips_invalid_keeps_valid(self): + """Given policy data with some entries missing required fields, + then skips invalid entries and keeps valid ones.""" + # When + result = self._convert_policy_data_to_reform(HOUSEHOLD_INCOMPLETE_POLICY_DATA) + + # Then + assert result == HOUSEHOLD_INCOMPLETE_POLICY_DATA_EXPECTED + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_user_household_associations.py b/tests/test_user_household_associations.py new file mode 100644 index 0000000..25d8989 --- /dev/null +++ b/tests/test_user_household_associations.py @@ -0,0 +1,189 @@ +"""Tests for user-household association endpoints.""" + +from uuid import uuid4 + +from test_fixtures.fixtures_user_household_associations import ( + create_association, + create_household, + create_user, +) + +# --------------------------------------------------------------------------- +# POST /user-household-associations +# --------------------------------------------------------------------------- + + +def test_create_association(client, session): + """Create an association returns 201 with id and timestamps.""" + user = create_user(session) + household = create_household(session) + payload = { + "user_id": str(user.id), + "household_id": 
str(household.id), + "country_id": "us", + "label": "My US household", + } + response = client.post("/user-household-associations", json=payload) + assert response.status_code == 201 + data = response.json() + assert "id" in data + assert "created_at" in data + assert "updated_at" in data + assert data["user_id"] == str(user.id) + assert data["household_id"] == str(household.id) + assert data["country_id"] == "us" + assert data["label"] == "My US household" + + +def test_create_association_allows_duplicates(client, session): + """Multiple associations to the same household are allowed.""" + user = create_user(session) + household = create_household(session) + payload = { + "user_id": str(user.id), + "household_id": str(household.id), + "country_id": "us", + "label": "First label", + } + r1 = client.post("/user-household-associations", json=payload) + assert r1.status_code == 201 + + payload["label"] = "Second label" + r2 = client.post("/user-household-associations", json=payload) + assert r2.status_code == 201 + assert r1.json()["id"] != r2.json()["id"] + + +def test_create_association_household_not_found(client, session): + """Creating with a non-existent household returns 404.""" + user = create_user(session) + payload = { + "user_id": str(user.id), + "household_id": str(uuid4()), + "country_id": "us", + } + response = client.post("/user-household-associations", json=payload) + assert response.status_code == 404 + assert "not found" in response.json()["detail"] + + +# --------------------------------------------------------------------------- +# GET /user-household-associations/user/{user_id} +# --------------------------------------------------------------------------- + + +def test_list_by_user_empty(client): + """List associations for a user with none returns empty list.""" + response = client.get(f"/user-household-associations/user/{uuid4()}") + assert response.status_code == 200 + assert response.json() == [] + + +def test_list_by_user(client, session): + 
"""List all associations for a user.""" + user = create_user(session) + h1 = create_household(session, label="H1") + h2 = create_household(session, label="H2") + create_association(session, user.id, h1.id, label="First") + create_association(session, user.id, h2.id, label="Second") + + response = client.get(f"/user-household-associations/user/{user.id}") + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + + +def test_list_by_user_filter_country(client, session): + """Filter associations by country_id.""" + user = create_user(session) + household = create_household(session) + create_association(session, user.id, household.id, country_id="us") + create_association(session, user.id, household.id, country_id="uk") + + response = client.get( + f"/user-household-associations/user/{user.id}", + params={"country_id": "uk"}, + ) + data = response.json() + assert len(data) == 1 + assert data[0]["country_id"] == "uk" + + +# --------------------------------------------------------------------------- +# GET /user-household-associations/{user_id}/{household_id} +# --------------------------------------------------------------------------- + + +def test_list_by_user_and_household(client, session): + """List associations for a specific user+household pair.""" + user = create_user(session) + household = create_household(session) + create_association(session, user.id, household.id, label="Label A") + create_association(session, user.id, household.id, label="Label B") + + response = client.get(f"/user-household-associations/{user.id}/{household.id}") + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + + +def test_list_by_user_and_household_empty(client): + """Returns empty list when no associations exist for the pair.""" + response = client.get(f"/user-household-associations/{uuid4()}/{uuid4()}") + assert response.status_code == 200 + assert response.json() == [] + + +# 
--------------------------------------------------------------------------- +# PUT /user-household-associations/{association_id} +# --------------------------------------------------------------------------- + + +def test_update_association_label(client, session): + """Update label and verify updated_at changes.""" + user = create_user(session) + household = create_household(session) + assoc = create_association(session, user.id, household.id, label="Old") + + response = client.put( + f"/user-household-associations/{assoc.id}", + json={"label": "New label"}, + ) + assert response.status_code == 200 + data = response.json() + assert data["label"] == "New label" + + +def test_update_association_not_found(client): + """Update a non-existent association returns 404.""" + response = client.put( + f"/user-household-associations/{uuid4()}", + json={"label": "Something"}, + ) + assert response.status_code == 404 + assert "not found" in response.json()["detail"] + + +# --------------------------------------------------------------------------- +# DELETE /user-household-associations/{association_id} +# --------------------------------------------------------------------------- + + +def test_delete_association(client, session): + """Delete an association returns 204.""" + user = create_user(session) + household = create_household(session) + assoc = create_association(session, user.id, household.id) + + response = client.delete(f"/user-household-associations/{assoc.id}") + assert response.status_code == 204 + + # Confirm it's gone + response = client.get(f"/user-household-associations/{user.id}/{household.id}") + assert response.json() == [] + + +def test_delete_association_not_found(client): + """Delete a non-existent association returns 404.""" + response = client.delete(f"/user-household-associations/{uuid4()}") + assert response.status_code == 404 diff --git a/uv.lock b/uv.lock index 094ebf8..466caf4 100644 --- a/uv.lock +++ b/uv.lock @@ -91,6 +91,20 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" }, ] +[[package]] +name = "alembic" +version = "1.18.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mako" }, + { name = "sqlalchemy" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/79/41/ab8f624929847b49f84955c594b165855efd829b0c271e1a8cac694138e5/alembic-1.18.3.tar.gz", hash = "sha256:1212aa3778626f2b0f0aa6dd4e99a5f99b94bd25a0c1ac0bba3be65e081e50b0", size = 2052564, upload-time = "2026-01-29T20:24:15.124Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/45/8e/d79281f323e7469b060f15bd229e48d7cdd219559e67e71c013720a88340/alembic-1.18.3-py3-none-any.whl", hash = "sha256:12a0359bfc068a4ecbb9b3b02cf77856033abfdb59e4a5aca08b7eacd7b74ddd", size = 262282, upload-time = "2026-01-29T20:24:17.488Z" }, +] + [[package]] name = "annotated-doc" version = "0.0.4" @@ -1057,6 +1071,18 @@ sqlalchemy = [ { name = "opentelemetry-instrumentation-sqlalchemy" }, ] +[[package]] +name = "mako" +version = "1.3.10" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markupsafe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9e/38/bd5b78a920a64d708fe6bc8e0a2c075e1389d53bef8413725c63ba041535/mako-1.3.10.tar.gz", hash = "sha256:99579a6f39583fa7e5630a28c3c1f440e4e97a414b80372649c0ce338da2ea28", size = 392474, upload-time = "2025-04-10T12:44:31.16Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/87/fb/99f81ac72ae23375f22b7afdb7642aba97c00a713c217124420147681a2f/mako-1.3.10-py3-none-any.whl", hash = "sha256:baef24a52fc4fc514a0887ac600f9f1cff3d82c61d4d700a1fa84d597b88db59", size = 78509, upload-time = "2025-04-10T12:50:53.297Z" }, +] + [[package]] name = 
"markdown-it-py" version = "4.0.0" @@ -1757,6 +1783,7 @@ name = "policyengine-api-v2" version = "0.1.0" source = { editable = "." } dependencies = [ + { name = "alembic" }, { name = "anthropic" }, { name = "boto3" }, { name = "fastapi" }, @@ -1793,6 +1820,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "alembic", specifier = ">=1.13.0" }, { name = "anthropic", specifier = ">=0.40.0" }, { name = "boto3", specifier = ">=1.41.1" }, { name = "fastapi", specifier = ">=0.115.0" },