From 341e92e7134b125a96657d7e0f285021ca61ba06 Mon Sep 17 00:00:00 2001 From: juaristi22 Date: Wed, 28 Jan 2026 18:10:53 +0530 Subject: [PATCH 1/4] simple and hierarchical fallback deduplication --- changelog_entry.yaml | 4 + .../calibration_utils.py | 75 +++ .../fit_calibration_weights.py | 25 +- .../sparse_matrix_builder.py | 282 ++++++++++- .../test_concept_deduplication.py | 439 ++++++++++++++++++ 5 files changed, 811 insertions(+), 14 deletions(-) create mode 100644 policyengine_us_data/tests/test_local_area_calibration/test_concept_deduplication.py diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29bb..e64a6b2a6 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,4 @@ +- bump: minor + changes: + added: + - Deduplication logic in SparseMatrixBuilder (option to remove duplicate targets or select most specific geographic level). \ No newline at end of file diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/calibration_utils.py b/policyengine_us_data/datasets/cps/local_area_calibration/calibration_utils.py index 6a5c415ec..a99727b06 100644 --- a/policyengine_us_data/datasets/cps/local_area_calibration/calibration_utils.py +++ b/policyengine_us_data/datasets/cps/local_area_calibration/calibration_utils.py @@ -571,3 +571,78 @@ def calculate_spm_thresholds_for_cd( thresholds[i] = base * equiv_scale * geoadj return thresholds + + +def build_concept_id(variable: str, constraints: List[str]) -> str: + """ + Build normalized concept ID from variable + constraints. + + The concept ID uniquely identifies a calibration target "concept" + based on the variable being measured and its non-geographic constraints. + + Args: + variable: Target variable name (e.g., "person_count", "snap") + constraints: List of constraint strings (e.g., ["age>=5", "age<18"]) + + Returns: + Normalized concept ID string + + Examples: + >>> build_concept_id("person_count", ["age>=5", "age<18"]) + 'person_count_age_gte_5_age_lt_18' + >>> build_concept_id("snap", ["snap>0"]) + 'snap_snap_gt_0' + >>> build_concept_id("snap", []) + 'snap' + """ + if not constraints: + return variable + + # Normalize and sort constraints for consistent IDs + normalized = [] + for c in sorted(constraints): + c_norm = ( + c.replace(">=", "_gte_") + .replace("<=", "_lte_") + .replace(">", "_gt_") + .replace("<", "_lt_") + .replace("==", "_eq_") + .replace("=", "_eq_") + .replace(" ", "") + ) + normalized.append(c_norm) + + return f"{variable}_{'_'.join(normalized)}" + + +def extract_constraints_from_row( + row: pd.Series, exclude_geo: bool = True +) -> List[str]: + """ + Extract constraint list from a target row's constraint_info column. 
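+
+    Example (illustrative, using the constraint format described under
+    Args below):
+
+        >>> row = pd.Series({"constraint_info": "age>=5|age<18|state_fips=6"})
+        >>> extract_constraints_from_row(row)
+        ['age>=5', 'age<18']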
+ + Args: + row: DataFrame row with 'constraint_info' column containing + pipe-separated constraints (e.g., "age>=5|age<18|state_fips=6") + exclude_geo: If True, filter out geographic constraints + (state_fips, congressional_district_geoid, tax_unit_is_filer) + + Returns: + List of constraint strings like ["age>=5", "age<18"] + """ + if "constraint_info" not in row or pd.isna(row["constraint_info"]): + return [] + + constraints = row["constraint_info"].split("|") + + if exclude_geo: + geo_vars = [ + "state_fips", + "congressional_district_geoid", + "tax_unit_is_filer", + ] + constraints = [ + c for c in constraints if not any(geo in c for geo in geo_vars) + ] + + return constraints diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/fit_calibration_weights.py b/policyengine_us_data/datasets/cps/local_area_calibration/fit_calibration_weights.py index b22b8eb45..577344bd4 100644 --- a/policyengine_us_data/datasets/cps/local_area_calibration/fit_calibration_weights.py +++ b/policyengine_us_data/datasets/cps/local_area_calibration/fit_calibration_weights.py @@ -112,11 +112,20 @@ "state_income_tax", # Census STC state income tax collections ], }, + deduplicate=True, + dedup_mode="within_geography", ) -print(f"Matrix shape: {X_sparse.shape}") -print(f"Targets: {len(targets_df)}") +# Print concept and deduplication summaries +builder.print_concept_summary() +builder.print_dedup_summary() +print(f"\nMatrix shape: {X_sparse.shape}") +print(f"Targets after deduplication: {len(targets_df)}") + +# ============================================================================ +# STEP 2: FILTER TO ACHIEVABLE TARGETS +# ============================================================================ # Filter to achievable targets (rows with non-zero data) row_sums = np.array(X_sparse.sum(axis=1)).flatten() achievable_mask = row_sums > 0 @@ -129,7 +138,7 @@ targets_df = targets_df[achievable_mask].reset_index(drop=True) X_sparse = X_sparse[achievable_mask, :] -print(f"Filtered matrix shape: {X_sparse.shape}") +print(f"Final matrix shape: {X_sparse.shape}") # Extract target vector and names targets = targets_df["value"].values @@ -139,14 +148,14 @@ ] # ============================================================================ -# STEP 2: INITIALIZE WEIGHTS +# STEP 3: INITIALIZE WEIGHTS # ============================================================================ initial_weights = np.ones(X_sparse.shape[1]) * 100 print(f"\nInitial weights shape: {initial_weights.shape}") print(f"Initial weights sum: {initial_weights.sum():,.0f}") # ============================================================================ -# STEP 3: CREATE MODEL +# STEP 4: CREATE MODEL # ============================================================================ print("\nCreating SparseCalibrationWeights model...") model = SparseCalibrationWeights( @@ -162,7 +171,7 @@ ) # ============================================================================ -# STEP 4: TRAIN IN CHUNKS +# STEP 5: TRAIN IN CHUNKS # ============================================================================ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") calibration_log = pd.DataFrame() @@ -205,7 +214,7 @@ calibration_log = pd.concat([calibration_log, chunk_df], ignore_index=True) # ============================================================================ -# STEP 5: EXTRACT AND SAVE WEIGHTS +# STEP 6: EXTRACT AND SAVE WEIGHTS # ============================================================================ with torch.no_grad(): w = 
model.get_weights(deterministic=True).cpu().numpy() @@ -225,7 +234,7 @@ print(f"LOG_PATH:{log_path}") # ============================================================================ -# STEP 6: VERIFY PREDICTIONS +# STEP 7: VERIFY PREDICTIONS # ============================================================================ print("\n" + "=" * 60) print("PREDICTION VERIFICATION") diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/sparse_matrix_builder.py b/policyengine_us_data/datasets/cps/local_area_calibration/sparse_matrix_builder.py index 5ffe5474a..fcdd8a33f 100644 --- a/policyengine_us_data/datasets/cps/local_area_calibration/sparse_matrix_builder.py +++ b/policyengine_us_data/datasets/cps/local_area_calibration/sparse_matrix_builder.py @@ -12,7 +12,8 @@ import numpy as np import pandas as pd from scipy import sparse -from sqlalchemy import create_engine, text +from dataclasses import dataclass +from sqlalchemy import create_engine logger = logging.getLogger(__name__) @@ -20,9 +21,21 @@ get_calculated_variables, apply_op, _get_geo_level, + build_concept_id, + extract_constraints_from_row, ) +@dataclass +class ConceptDuplicateWarning: + """Warning when multiple values exist for the same concept.""" + + concept_id: str + duplicates: List[dict] + selected: dict + reason: str + + class SparseMatrixBuilder: """Build sparse calibration matrices for geo-stacking.""" @@ -40,6 +53,12 @@ def __init__( self.dataset_path = dataset_path self._entity_rel_cache = None + # Populated after build_matrix() with deduplicate=True + self.concept_summary: Optional[pd.DataFrame] = None + self.dedup_warnings: List[ConceptDuplicateWarning] = [] + self.targets_before_dedup: Optional[pd.DataFrame] = None + self.targets_after_dedup: Optional[pd.DataFrame] = None + def _build_entity_relationship(self, sim) -> pd.DataFrame: """ Build entity relationship DataFrame mapping persons to all entity IDs. @@ -198,6 +217,235 @@ def _get_geographic_id(self, stratum_id: int) -> str: return c["value"] return "US" + def _get_constraint_info(self, stratum_id: int) -> str: + """Build pipe-separated constraint string for concept identification.""" + constraints = self._get_constraints(stratum_id) + parts = [] + for c in constraints: + op = "==" if c["operation"] == "=" else c["operation"] + parts.append(f"{c['variable']}{op}{c['value']}") + return "|".join(parts) if parts else None + + def _deduplicate_targets( + self, + targets_df: pd.DataFrame, + mode: str = "within_geography", + priority_column: str = "geo_priority", + ) -> pd.DataFrame: + """ + Deduplicate targets by concept before matrix building. + + Stores results in instance attributes for later inspection: + - self.concept_summary: DataFrame summarizing concepts + - self.dedup_warnings: List of ConceptDuplicateWarning + - self.targets_before_dedup: Original targets DataFrame + - self.targets_after_dedup: Deduplicated targets DataFrame + + Args: + targets_df: DataFrame with target rows including geographic_id + and constraint_info columns + mode: Deduplication mode ("within_geography" or + "hierarchical_fallback") + priority_column: Column to sort by when selecting among + duplicates. Lower values = higher priority. 
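+                Example (illustrative): under mode="hierarchical_fallback",
+                a concept reported at CD, state, and national level keeps
+                only the CD row, because geo_priority 1 (CD) sorts ahead
+                of 2 (state) and 3 (national).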
+ + Returns: + Deduplicated DataFrame with reset index + """ + df = targets_df.copy() + + # Add geo_priority if not present (CD=1, State=2, National=3) + if priority_column not in df.columns: + df["geo_priority"] = df["geographic_id"].apply( + lambda g: 3 if g == "US" else (1 if int(g) >= 100 else 2) + ) + priority_column = "geo_priority" + + # Build concept_id for each row + df["_concept_id"] = df.apply( + lambda row: build_concept_id( + row["variable"], + extract_constraints_from_row(row, exclude_geo=True), + ), + axis=1, + ) + + # Store concept summary + self.concept_summary = df.groupby("_concept_id").agg( + count=("_concept_id", "size"), + variable=("variable", "first"), + geos=("geographic_id", lambda x: list(x.unique())), + ) + + # Store original for comparison + self.targets_before_dedup = df.copy() + + # Determine deduplication key based on mode + if mode == "within_geography": + if "geographic_id" not in df.columns: + raise ValueError( + "Mode 'within_geography' requires 'geographic_id' column" + ) + dedupe_key = ["_concept_id", "geographic_id"] + elif mode == "hierarchical_fallback": + dedupe_key = ["_concept_id"] + else: + raise ValueError( + f"Unknown mode '{mode}'. Use 'within_geography' or " + "'hierarchical_fallback'" + ) + + # Find and process duplicates + warnings = [] + duplicate_mask = df.duplicated(subset=dedupe_key, keep=False) + duplicates_df = df[duplicate_mask] + + if len(duplicates_df) > 0: + for key_vals, group in duplicates_df.groupby(dedupe_key): + if len(group) <= 1: + continue + + dup_list = [] + for _, dup_row in group.iterrows(): + dup_list.append( + { + "geographic_id": dup_row.get("geographic_id", "?"), + "source": dup_row.get("source_name", "?"), + "period": dup_row.get("period", "?"), + "value": dup_row.get("value", "?"), + "stratum_id": dup_row.get("stratum_id", "?"), + } + ) + + sorted_group = group.sort_values(priority_column) + selected_row = sorted_group.iloc[0] + selected = { + "geographic_id": selected_row.get("geographic_id", "?"), + "source": selected_row.get("source_name", "?"), + "period": selected_row.get("period", "?"), + "value": selected_row.get("value", "?"), + } + + concept_id = ( + key_vals if isinstance(key_vals, str) else key_vals[0] + ) + warnings.append( + ConceptDuplicateWarning( + concept_id=concept_id, + duplicates=dup_list, + selected=selected, + reason=f"Selected by lowest {priority_column}", + ) + ) + + self.dedup_warnings = warnings + + # Deduplicate: sort by key + priority, keep first per key + sort_cols = ( + dedupe_key + [priority_column] + if priority_column in df.columns + else dedupe_key + ) + df_sorted = df.sort_values(sort_cols) + df_deduped = df_sorted.drop_duplicates(subset=dedupe_key, keep="first") + + # Clean up temporary column + df_deduped = df_deduped.drop(columns=["_concept_id"]) + + self.targets_after_dedup = df_deduped.copy() + + return df_deduped.reset_index(drop=True) + + def print_concept_summary(self) -> None: + """ + Print detailed concept summary from the last build_matrix() call. + + Call this after build_matrix() to see what concepts were found. + """ + if self.concept_summary is None: + print("No concept summary available. 
Run build_matrix() first.") + return + + print("\n" + "=" * 60) + print("CONCEPT SUMMARY") + print("=" * 60) + + n_targets = ( + len(self.targets_before_dedup) + if self.targets_before_dedup is not None + else 0 + ) + print( + f"Found {len(self.concept_summary)} unique concepts " + f"from {n_targets} targets:\n" + ) + + for concept_id, row in self.concept_summary.iterrows(): + n_geos = len(row["geos"]) + print(f" {concept_id}") + print( + f" Variable: {row['variable']}, " + f"Targets: {row['count']}, Geographies: {n_geos}" + ) + + def print_dedup_summary(self) -> None: + """ + Print deduplication summary from the last build_matrix() call. + + Call this after build_matrix() to see what duplicates were removed. + """ + if self.targets_before_dedup is None: + print("No dedup summary available. Run build_matrix() first.") + return + + print("\n" + "=" * 60) + print("DEDUPLICATION SUMMARY") + print("=" * 60) + + before = len(self.targets_before_dedup) + after = ( + len(self.targets_after_dedup) + if self.targets_after_dedup is not None + else 0 + ) + removed = before - after + + print(f"Total targets queried: {before}") + print(f"Targets after deduplication: {after}") + print(f"Duplicates removed: {removed}") + + if self.dedup_warnings: + print(f"\nDuplicate groups resolved ({len(self.dedup_warnings)}):") + for w in self.dedup_warnings: + print(f"\n Concept: {w.concept_id}") + sel_val = w.selected["value"] + sel_val_str = ( + f"{sel_val:,.0f}" + if isinstance(sel_val, (int, float)) + else str(sel_val) + ) + print( + f" Selected: geo={w.selected['geographic_id']}, " + f"value={sel_val_str}" + ) + print(f" Removed ({len(w.duplicates) - 1}):") + for dup in w.duplicates: + if ( + dup["value"] != w.selected["value"] + or dup["geographic_id"] != w.selected["geographic_id"] + ): + dup_val = dup["value"] + dup_val_str = ( + f"{dup_val:,.0f}" + if isinstance(dup_val, (int, float)) + else str(dup_val) + ) + print( + f" - geo={dup['geographic_id']}, " + f"value={dup_val_str}, " + f"source={dup.get('source', '?')}" + ) + def _create_state_sim(self, state: int, n_households: int): """Create a fresh simulation with state_fips set to given state.""" from policyengine_us import Microsimulation @@ -213,19 +461,33 @@ def _create_state_sim(self, state: int, n_households: int): return state_sim def build_matrix( - self, sim, target_filter: dict + self, + sim, + target_filter: dict, + deduplicate: bool = True, + dedup_mode: str = "within_geography", ) -> Tuple[pd.DataFrame, sparse.csr_matrix, Dict[str, List[str]]]: """ Build sparse calibration matrix. Args: - sim: Microsimulation instance (used for household_ids, or as template) + sim: Microsimulation instance (used for household_ids, or + as template) target_filter: Dict specifying which targets to include - {"stratum_group_ids": [4]} for SNAP targets - {"target_ids": [123, 456]} for specific targets + deduplicate: If True, deduplicate targets by concept before + building the matrix (default True) + dedup_mode: Deduplication mode - "within_geography" (default) + removes duplicates with same concept AND geography, or + "hierarchical_fallback" keeps most specific geography + per concept Returns: Tuple of (targets_df, X_sparse, household_id_mapping) + + After calling this method, you can use print_concept_summary() and + print_dedup_summary() to see details about concepts and deduplication. 
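+
+        Example (illustrative; ``builder`` is a SparseMatrixBuilder
+        instance, and stratum group 4 is the SNAP group used in the
+        target_filter examples above):
+
+            targets_df, X_sparse, hh_map = builder.build_matrix(
+                sim,
+                {"stratum_group_ids": [4]},
+                deduplicate=True,
+                dedup_mode="within_geography",
+            )
+            builder.print_concept_summary()
+            builder.print_dedup_summary()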
""" household_ids = sim.calculate( "household_id", map_to="household" @@ -235,16 +497,24 @@ def build_matrix( n_cols = n_households * n_cds targets_df = self._query_targets(target_filter) - n_targets = len(targets_df) - if n_targets == 0: + if len(targets_df) == 0: raise ValueError("No targets found matching filter") targets_df["geographic_id"] = targets_df["stratum_id"].apply( self._get_geographic_id ) + targets_df["constraint_info"] = targets_df["stratum_id"].apply( + self._get_constraint_info + ) + + # Deduplicate targets by concept before building matrix + if deduplicate: + targets_df = self._deduplicate_targets(targets_df, mode=dedup_mode) + + n_targets = len(targets_df) - # Sort by (geo_level, variable, geographic_id) for contiguous group rows + # Sort by (geo_level, variable, geographic_id) for contiguous group targets_df["_geo_level"] = targets_df["geographic_id"].apply( _get_geo_level ) diff --git a/policyengine_us_data/tests/test_local_area_calibration/test_concept_deduplication.py b/policyengine_us_data/tests/test_local_area_calibration/test_concept_deduplication.py new file mode 100644 index 000000000..57fe510e7 --- /dev/null +++ b/policyengine_us_data/tests/test_local_area_calibration/test_concept_deduplication.py @@ -0,0 +1,439 @@ +""" +Tests for concept ID building, constraint extraction, and deduplication. + +These tests verify that: +1. Concept IDs are built consistently from variable + non-geo constraints +2. Constraints are correctly extracted from DataFrame rows +3. Deduplication correctly identifies and removes duplicates via the builder +""" + +import unittest +import tempfile +import os +import pandas as pd +from sqlalchemy import create_engine, text + +from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import ( + build_concept_id, + extract_constraints_from_row, +) +from policyengine_us_data.datasets.cps.local_area_calibration.sparse_matrix_builder import ( + SparseMatrixBuilder, +) + + +class TestBuildConceptId(unittest.TestCase): + """Test concept ID building from variable + constraints.""" + + def test_variable_only(self): + """Test concept ID with no constraints.""" + result = build_concept_id("snap", []) + self.assertEqual(result, "snap") + + def test_single_constraint(self): + """Test concept ID with single constraint.""" + result = build_concept_id("snap", ["snap>0"]) + self.assertEqual(result, "snap_snap_gt_0") + + def test_multiple_constraints_sorted(self): + """Test that constraints are sorted for consistency.""" + # Order shouldn't matter - result should be the same + result1 = build_concept_id("person_count", ["age>=5", "age<18"]) + result2 = build_concept_id("person_count", ["age<18", "age>=5"]) + self.assertEqual(result1, result2) + self.assertEqual(result1, "person_count_age_lt_18_age_gte_5") + + def test_operator_normalization(self): + """Test that operators are normalized correctly.""" + self.assertIn("_gte_", build_concept_id("x", ["a>=1"])) + self.assertIn("_lte_", build_concept_id("x", ["a<=1"])) + self.assertIn("_gt_", build_concept_id("x", ["a>1"])) + self.assertIn("_lt_", build_concept_id("x", ["a<1"])) + self.assertIn("_eq_", build_concept_id("x", ["a==1"])) + self.assertIn("_eq_", build_concept_id("x", ["a=1"])) + + def test_spaces_removed(self): + """Test that spaces are removed from constraints.""" + result = build_concept_id("x", ["age >= 5"]) + self.assertNotIn(" ", result) + + +class TestExtractConstraints(unittest.TestCase): + """Test constraint extraction from DataFrame rows.""" + + def 
test_no_constraint_info(self): + """Test row without constraint_info column.""" + row = pd.Series({"variable": "snap", "value": 1000}) + result = extract_constraints_from_row(row) + self.assertEqual(result, []) + + def test_null_constraint_info(self): + """Test row with null constraint_info.""" + row = pd.Series( + {"variable": "snap", "constraint_info": None, "value": 1000} + ) + result = extract_constraints_from_row(row) + self.assertEqual(result, []) + + def test_single_constraint(self): + """Test row with single constraint.""" + row = pd.Series( + {"variable": "snap", "constraint_info": "snap>0", "value": 1000} + ) + result = extract_constraints_from_row(row) + self.assertEqual(result, ["snap>0"]) + + def test_multiple_constraints(self): + """Test row with pipe-separated constraints.""" + row = pd.Series( + { + "variable": "person_count", + "constraint_info": "age>=5|age<18", + "value": 1000, + } + ) + result = extract_constraints_from_row(row) + self.assertEqual(result, ["age>=5", "age<18"]) + + def test_exclude_geo_constraints(self): + """Test that geographic constraints are excluded by default.""" + row = pd.Series( + { + "variable": "person_count", + "constraint_info": "age>=5|state_fips=6|age<18", + "value": 1000, + } + ) + result = extract_constraints_from_row(row, exclude_geo=True) + self.assertEqual(result, ["age>=5", "age<18"]) + self.assertNotIn("state_fips=6", result) + + def test_include_geo_constraints(self): + """Test that geographic constraints can be included.""" + row = pd.Series( + { + "variable": "person_count", + "constraint_info": "age>=5|state_fips=6", + "value": 1000, + } + ) + result = extract_constraints_from_row(row, exclude_geo=False) + self.assertIn("state_fips=6", result) + + def test_exclude_cd_geoid(self): + """Test that CD geoid constraints are excluded.""" + row = pd.Series( + { + "variable": "snap", + "constraint_info": "snap>0|congressional_district_geoid=601", + "value": 1000, + } + ) + result = extract_constraints_from_row(row, exclude_geo=True) + self.assertEqual(result, ["snap>0"]) + + def test_exclude_filer_constraint(self): + """Test that tax_unit_is_filer constraint is excluded.""" + row = pd.Series( + { + "variable": "income_tax", + "constraint_info": "tax_unit_is_filer=True|income>0", + "value": 1000, + } + ) + result = extract_constraints_from_row(row, exclude_geo=True) + self.assertEqual(result, ["income>0"]) + + +class TestBuilderDeduplication(unittest.TestCase): + """Test deduplication logic through SparseMatrixBuilder.""" + + @classmethod + def setUpClass(cls): + """Create a temporary database with test data.""" + cls.temp_db = tempfile.NamedTemporaryFile(suffix=".db", delete=False) + cls.db_path = cls.temp_db.name + cls.temp_db.close() + + cls.db_uri = f"sqlite:///{cls.db_path}" + engine = create_engine(cls.db_uri) + + # Create schema + with engine.connect() as conn: + conn.execute(text(""" + CREATE TABLE stratum_groups ( + stratum_group_id INTEGER PRIMARY KEY, + name TEXT + ) + """)) + conn.execute(text(""" + CREATE TABLE strata ( + stratum_id INTEGER PRIMARY KEY, + stratum_group_id INTEGER + ) + """)) + conn.execute(text(""" + CREATE TABLE stratum_constraints ( + constraint_id INTEGER PRIMARY KEY, + stratum_id INTEGER, + constraint_variable TEXT, + operation TEXT, + value TEXT + ) + """)) + conn.execute(text(""" + CREATE TABLE targets ( + target_id INTEGER PRIMARY KEY, + stratum_id INTEGER, + variable TEXT, + value REAL, + period INTEGER + ) + """)) + conn.commit() + + @classmethod + def tearDownClass(cls): + """Remove temporary 
database.""" + os.unlink(cls.db_path) + + def setUp(self): + """Clear tables before each test.""" + engine = create_engine(self.db_uri) + with engine.connect() as conn: + conn.execute(text("DELETE FROM targets")) + conn.execute(text("DELETE FROM stratum_constraints")) + conn.execute(text("DELETE FROM strata")) + conn.execute(text("DELETE FROM stratum_groups")) + conn.commit() + + def _insert_test_data(self, strata, constraints, targets): + """Helper to insert test data into database.""" + engine = create_engine(self.db_uri) + with engine.connect() as conn: + # Insert stratum groups + conn.execute( + text("INSERT OR IGNORE INTO stratum_groups VALUES (1, 'test')") + ) + + # Insert strata + for stratum_id, group_id in strata: + conn.execute( + text("INSERT INTO strata VALUES (:sid, :gid)"), + {"sid": stratum_id, "gid": group_id}, + ) + + # Insert constraints + for i, (stratum_id, var, op, val) in enumerate(constraints): + conn.execute( + text(""" + INSERT INTO stratum_constraints + VALUES (:cid, :sid, :var, :op, :val) + """), + { + "cid": i + 1, + "sid": stratum_id, + "var": var, + "op": op, + "val": val, + }, + ) + + # Insert targets + for i, (stratum_id, variable, value, period) in enumerate(targets): + conn.execute( + text(""" + INSERT INTO targets + VALUES (:tid, :sid, :var, :val, :period) + """), + { + "tid": i + 1, + "sid": stratum_id, + "var": variable, + "val": value, + "period": period, + }, + ) + + conn.commit() + + def test_no_duplicates_preserved(self): + """Test that targets with different concepts are all preserved.""" + # Two different variables for the same CD - should NOT deduplicate + self._insert_test_data( + strata=[(1, 1), (2, 1)], + constraints=[ + (1, "congressional_district_geoid", "=", "601"), + (2, "congressional_district_geoid", "=", "601"), + ], + targets=[ + (1, "snap", 1000, 2023), + (2, "medicaid", 2000, 2023), + ], + ) + + builder = SparseMatrixBuilder( + db_uri=self.db_uri, + time_period=2023, + cds_to_calibrate=["601"], + ) + + # Call _deduplicate_targets directly with prepared DataFrame + targets_df = builder._query_targets({"stratum_group_ids": [1]}) + targets_df["geographic_id"] = targets_df["stratum_id"].apply( + builder._get_geographic_id + ) + targets_df["constraint_info"] = targets_df["stratum_id"].apply( + builder._get_constraint_info + ) + + result = builder._deduplicate_targets(targets_df) + + self.assertEqual(len(result), 2) + self.assertEqual(len(builder.dedup_warnings), 0) + + def test_duplicate_same_geo_deduplicated(self): + """Test that same concept at same geography is deduplicated.""" + # Same variable, same CD, different periods - should deduplicate + self._insert_test_data( + strata=[(1, 1), (2, 1)], + constraints=[ + (1, "congressional_district_geoid", "=", "601"), + (2, "congressional_district_geoid", "=", "601"), + ], + targets=[ + (1, "snap", 1000, 2023), + (2, "snap", 1100, 2022), # Same concept, same geo + ], + ) + + builder = SparseMatrixBuilder( + db_uri=self.db_uri, + time_period=2023, + cds_to_calibrate=["601"], + ) + + targets_df = builder._query_targets({"stratum_group_ids": [1]}) + targets_df["geographic_id"] = targets_df["stratum_id"].apply( + builder._get_geographic_id + ) + targets_df["constraint_info"] = targets_df["stratum_id"].apply( + builder._get_constraint_info + ) + + result = builder._deduplicate_targets(targets_df) + + self.assertEqual(len(result), 1) + self.assertEqual(len(builder.dedup_warnings), 1) + + def test_same_concept_different_geos_preserved(self): + """Test that same concept at different geos is NOT 
deduplicated.""" + # Same variable, different CDs - should NOT deduplicate + self._insert_test_data( + strata=[(1, 1), (2, 1)], + constraints=[ + (1, "congressional_district_geoid", "=", "601"), + (2, "congressional_district_geoid", "=", "602"), + ], + targets=[ + (1, "snap", 1000, 2023), + (2, "snap", 1100, 2023), + ], + ) + + builder = SparseMatrixBuilder( + db_uri=self.db_uri, + time_period=2023, + cds_to_calibrate=["601", "602"], + ) + + targets_df = builder._query_targets({"stratum_group_ids": [1]}) + targets_df["geographic_id"] = targets_df["stratum_id"].apply( + builder._get_geographic_id + ) + targets_df["constraint_info"] = targets_df["stratum_id"].apply( + builder._get_constraint_info + ) + + result = builder._deduplicate_targets(targets_df) + + self.assertEqual(len(result), 2) # Both kept + self.assertEqual(len(builder.dedup_warnings), 0) + + def test_different_constraints_different_concepts(self): + """Test that different constraints create different concepts.""" + # Same variable but different age constraints - different concepts + self._insert_test_data( + strata=[(1, 1), (2, 1)], + constraints=[ + (1, "congressional_district_geoid", "=", "601"), + (1, "age", ">=", "5"), + (1, "age", "<", "18"), + (2, "congressional_district_geoid", "=", "601"), + (2, "age", ">=", "18"), + (2, "age", "<", "65"), + ], + targets=[ + (1, "person_count", 1000, 2023), + (2, "person_count", 2000, 2023), + ], + ) + + builder = SparseMatrixBuilder( + db_uri=self.db_uri, + time_period=2023, + cds_to_calibrate=["601"], + ) + + targets_df = builder._query_targets({"stratum_group_ids": [1]}) + targets_df["geographic_id"] = targets_df["stratum_id"].apply( + builder._get_geographic_id + ) + targets_df["constraint_info"] = targets_df["stratum_id"].apply( + builder._get_constraint_info + ) + + result = builder._deduplicate_targets(targets_df) + + self.assertEqual(len(result), 2) # Different concepts + self.assertEqual(len(builder.dedup_warnings), 0) + + def test_hierarchical_fallback_keeps_most_specific(self): + """Test hierarchical fallback mode keeps CD over state over national.""" + # Same concept at CD, state, and national levels + self._insert_test_data( + strata=[(1, 1), (2, 1), (3, 1)], + constraints=[ + (1, "congressional_district_geoid", "=", "601"), + (2, "state_fips", "=", "6"), + # stratum 3 has no geo constraint = national + ], + targets=[ + (1, "snap", 1200000, 2023), # CD level + (2, "snap", 15000000, 2023), # State level + (3, "snap", 110000000000, 2023), # National level + ], + ) + + builder = SparseMatrixBuilder( + db_uri=self.db_uri, + time_period=2023, + cds_to_calibrate=["601"], + ) + + targets_df = builder._query_targets({"stratum_group_ids": [1]}) + targets_df["geographic_id"] = targets_df["stratum_id"].apply( + builder._get_geographic_id + ) + targets_df["constraint_info"] = targets_df["stratum_id"].apply( + builder._get_constraint_info + ) + + result = builder._deduplicate_targets( + targets_df, mode="hierarchical_fallback" + ) + + self.assertEqual(len(result), 1) + # CD level should be kept (geo_priority=1) + self.assertEqual(result.iloc[0]["geographic_id"], "601") + self.assertEqual(result.iloc[0]["value"], 1200000) From 4033acab36a13a14a708c1bb889b7968b4bb8fbf Mon Sep 17 00:00:00 2001 From: juaristi22 Date: Wed, 28 Jan 2026 19:31:04 +0530 Subject: [PATCH 2/4] entity aware target variable calculations --- changelog_entry.yaml | 3 +- .../sparse_matrix_builder.py | 135 ++++++++++++++++-- 2 files changed, 123 insertions(+), 15 deletions(-) diff --git a/changelog_entry.yaml 
b/changelog_entry.yaml index e64a6b2a6..bfd1545c4 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -1,4 +1,5 @@ - bump: minor changes: added: - - Deduplication logic in SparseMatrixBuilder (option to remove duplicate targets or select most specific geographic level). \ No newline at end of file + - Deduplication logic in SparseMatrixBuilder (option to remove duplicate targets or select most specific geographic level). + - Entity aware target calculations for correct entity counts. \ No newline at end of file diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/sparse_matrix_builder.py b/policyengine_us_data/datasets/cps/local_area_calibration/sparse_matrix_builder.py index fcdd8a33f..886cd9b0f 100644 --- a/policyengine_us_data/datasets/cps/local_area_calibration/sparse_matrix_builder.py +++ b/policyengine_us_data/datasets/cps/local_area_calibration/sparse_matrix_builder.py @@ -157,6 +157,117 @@ def _evaluate_constraints_entity_aware( return household_mask + def _calculate_target_values_entity_aware( + self, + state_sim, + target_variable: str, + non_geo_constraints: List[dict], + geo_mask: np.ndarray, + n_households: int, + ) -> np.ndarray: + """ + Calculate target values at household level, handling count targets. + + For count targets (*_count): Count entities per household satisfying + constraints + For value targets: Sum values at household level (existing behavior) + + Args: + state_sim: Microsimulation with state_fips set + target_variable: The target variable name (e.g., "snap", + "person_count") + non_geo_constraints: List of constraint dicts (geographic + constraints should be pre-filtered) + geo_mask: Boolean mask array for geographic filtering (household + level) + n_households: Number of households + + Returns: + Float array of target values at household level + """ + is_count_target = target_variable.endswith("_count") + + if not is_count_target: + # Value target: use existing entity-aware constraint evaluation + entity_mask = self._evaluate_constraints_entity_aware( + state_sim, non_geo_constraints, n_households + ) + mask = geo_mask & entity_mask + + target_values = state_sim.calculate( + target_variable, map_to="household" + ).values + return (target_values * mask).astype(np.float32) + + # Count target: need to count entities satisfying constraints + entity_rel = self._build_entity_relationship(state_sim) + n_persons = len(entity_rel) + + # Evaluate constraints at person level (don't aggregate to HH yet) + person_mask = np.ones(n_persons, dtype=bool) + for c in non_geo_constraints: + constraint_values = state_sim.calculate( + c["variable"], map_to="person" + ).values + person_mask &= apply_op( + constraint_values, c["operation"], c["value"] + ) + + # Get target entity from variable definition + target_entity = state_sim.tax_benefit_system.variables[ + target_variable + ].entity.key + + household_ids = state_sim.calculate( + "household_id", map_to="household" + ).values + geo_mask_map = dict(zip(household_ids, geo_mask)) + + if target_entity == "household": + # household_count: 1 per qualifying household + if non_geo_constraints: + entity_mask = self._evaluate_constraints_entity_aware( + state_sim, non_geo_constraints, n_households + ) + return (geo_mask & entity_mask).astype(np.float32) + return geo_mask.astype(np.float32) + + if target_entity == "person": + # Count persons satisfying constraints per household + entity_rel["satisfies"] = person_mask + entity_rel["geo_ok"] = entity_rel["household_id"].map(geo_mask_map) + filtered = entity_rel[ 
+ entity_rel["satisfies"] & entity_rel["geo_ok"] + ] + counts = filtered.groupby("household_id")["person_id"].nunique() + else: + # For tax_unit, spm_unit: aggregate person mask to entity, then + # count + entity_id_col = f"{target_entity}_id" + entity_rel["satisfies"] = person_mask + entity_satisfies = entity_rel.groupby(entity_id_col)[ + "satisfies" + ].any() + + entity_rel_unique = entity_rel[ + ["household_id", entity_id_col] + ].drop_duplicates() + entity_rel_unique["entity_ok"] = entity_rel_unique[ + entity_id_col + ].map(entity_satisfies) + entity_rel_unique["geo_ok"] = entity_rel_unique[ + "household_id" + ].map(geo_mask_map) + filtered = entity_rel_unique[ + entity_rel_unique["entity_ok"] & entity_rel_unique["geo_ok"] + ] + counts = filtered.groupby("household_id")[entity_id_col].nunique() + + # Build result aligned with household order + return np.array( + [counts.get(hh_id, 0) for hh_id in household_ids], dtype=np.float32 + ) + def _query_targets(self, target_filter: dict) -> pd.DataFrame: """Query targets based on filter criteria using OR logic.""" or_conditions = [] @@ -586,24 +697,20 @@ def build_matrix( if not geo_mask.any(): continue - # Evaluate non-geographic constraints at entity level - entity_mask = self._evaluate_constraints_entity_aware( - state_sim, non_geo_constraints, n_households + # Calculate target values with entity-aware handling + # This properly handles count targets (*_count) by counting + # entities rather than summing values + masked_values = self._calculate_target_values_entity_aware( + state_sim, + target["variable"], + non_geo_constraints, + geo_mask, + n_households, ) - # Combine geographic and entity-aware masks - mask = geo_mask & entity_mask - - if not mask.any(): + if not masked_values.any(): continue - target_values = state_sim.calculate( - target["variable"], - self.time_period, - map_to="household", - ).values - masked_values = (target_values * mask).astype(np.float32) - nonzero = np.where(masked_values != 0)[0] if len(nonzero) > 0: X[row_idx, col_start + nonzero] = masked_values[ From 9ef84f286a50012c2756a148a5d34fd140b27932 Mon Sep 17 00:00:00 2001 From: juaristi22 Date: Thu, 29 Jan 2026 15:39:22 +0530 Subject: [PATCH 3/4] add count variable tests --- .../sparse_matrix_builder.py | 10 +- .../test_local_area_calibration/conftest.py | 5 + .../test_count_targets.py | 415 ++++++++++++++++++ 3 files changed, 425 insertions(+), 5 deletions(-) create mode 100644 policyengine_us_data/tests/test_local_area_calibration/test_count_targets.py diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/sparse_matrix_builder.py b/policyengine_us_data/datasets/cps/local_area_calibration/sparse_matrix_builder.py index 886cd9b0f..f22286f4e 100644 --- a/policyengine_us_data/datasets/cps/local_area_calibration/sparse_matrix_builder.py +++ b/policyengine_us_data/datasets/cps/local_area_calibration/sparse_matrix_builder.py @@ -289,11 +289,10 @@ def _query_targets(self, target_filter: dict) -> pd.DataFrame: or_conditions.append(f"t.stratum_id IN ({ids})") if not or_conditions: - raise ValueError( - "target_filter must specify at least one filter criterion" - ) - - where_clause = " OR ".join(f"({c})" for c in or_conditions) + # No filter criteria: fetch all targets + where_clause = "1=1" + else: + where_clause = " OR ".join(f"({c})" for c in or_conditions) query = f""" SELECT t.target_id, t.stratum_id, t.variable, t.value, t.period, @@ -587,6 +586,7 @@ def build_matrix( target_filter: Dict specifying which targets to include - {"stratum_group_ids": [4]} 
for SNAP targets - {"target_ids": [123, 456]} for specific targets + - an empty dict {} will fetch all targets deduplicate: If True, deduplicate targets by concept before building the matrix (default True) dedup_mode: Deduplication mode - "within_geography" (default) diff --git a/policyengine_us_data/tests/test_local_area_calibration/conftest.py b/policyengine_us_data/tests/test_local_area_calibration/conftest.py index 20b7f05bd..d2708906d 100644 --- a/policyengine_us_data/tests/test_local_area_calibration/conftest.py +++ b/policyengine_us_data/tests/test_local_area_calibration/conftest.py @@ -23,6 +23,11 @@ # Format: (variable_name, rtol) # variable_name as per the targets in policy_data.db # rtol is relative tolerance for comparison +# +# NOTE: Count targets (person_count, tax_unit_count) are excluded because +# they have constraints (e.g., age>=5|age<18) that make the X_sparse values +# different from raw sim.calculate() values. Count targets are tested +# separately in test_count_targets.py with controlled mock data. VARIABLES_TO_TEST = [ ("snap", 1e-2), ("income_tax", 1e-2), diff --git a/policyengine_us_data/tests/test_local_area_calibration/test_count_targets.py b/policyengine_us_data/tests/test_local_area_calibration/test_count_targets.py new file mode 100644 index 000000000..46eae4ebb --- /dev/null +++ b/policyengine_us_data/tests/test_local_area_calibration/test_count_targets.py @@ -0,0 +1,415 @@ +""" +Tests for count target handling in SparseMatrixBuilder. + +These tests verify that count targets (e.g., person_count, tax_unit_count) +are correctly handled by counting entities that satisfy constraints, rather +than summing values. +""" + +import pytest +import numpy as np +from dataclasses import dataclass + +from policyengine_us_data.datasets.cps.local_area_calibration.sparse_matrix_builder import ( + SparseMatrixBuilder, +) + + +@dataclass +class MockEntity: + """Mock entity with a key attribute.""" + + key: str + + +@dataclass +class MockVariable: + """Mock variable with entity information.""" + + entity: MockEntity + + @classmethod + def create(cls, entity_key: str) -> "MockVariable": + return cls(entity=MockEntity(key=entity_key)) + + +class MockTaxBenefitSystem: + """Mock tax benefit system with variable definitions.""" + + def __init__(self): + self.variables = { + "person_count": MockVariable.create("person"), + "tax_unit_count": MockVariable.create("tax_unit"), + "household_count": MockVariable.create("household"), + "spm_unit_count": MockVariable.create("spm_unit"), + "snap": MockVariable.create("spm_unit"), + } + + +@dataclass +class MockCalculationResult: + """Mock result from simulation.calculate().""" + + values: np.ndarray + + +class MockSimulation: + """Mock simulation for testing count target calculations.""" + + def __init__(self, entity_data: dict, variable_values: dict): + """ + Args: + entity_data: Dict with person_id, household_id, tax_unit_id, + spm_unit_id arrays (all at person level) + variable_values: Dict mapping variable names to their values + at the appropriate entity level + """ + self.entity_data = entity_data + self.variable_values = variable_values + self.tax_benefit_system = MockTaxBenefitSystem() + + def calculate(self, variable: str, map_to: str = None): + """Return mock calculation result.""" + if variable in self.entity_data: + # Entity ID variables + if map_to == "person": + values = np.array(self.entity_data[variable]) + elif map_to == "household": + # Return unique household IDs + values = np.array( + 
sorted(set(self.entity_data["household_id"])) + ) + else: + values = np.array(self.entity_data[variable]) + elif variable in self.variable_values: + # Regular variables - return at requested level + val_data = self.variable_values[variable] + if map_to == "person": + values = np.array(val_data["person"]) + elif map_to == "household": + values = np.array(val_data["household"]) + else: + values = np.array(val_data.get("default", [])) + else: + values = np.array([]) + + return MockCalculationResult(values=values) + + +@pytest.fixture +def basic_entity_data(): + """ + Create mock entity relationships with known household compositions. + + Household 1 (id=100): 3 people (ages 5, 12, 40) -> 2 aged 5-17 + Household 2 (id=200): 2 people (ages 3, 25) -> 0 aged 5-17 + Household 3 (id=300): 4 people (ages 6, 8, 10, 45) -> 3 aged 5-17 + """ + return { + "person_id": [1, 2, 3, 4, 5, 6, 7, 8, 9], + "household_id": [100, 100, 100, 200, 200, 300, 300, 300, 300], + "tax_unit_id": [10, 10, 10, 20, 20, 30, 30, 30, 30], + "spm_unit_id": [ + 1000, + 1000, + 1000, + 2000, + 2000, + 3000, + 3000, + 3000, + 3000, + ], + } + + +@pytest.fixture +def basic_variable_values(): + """Variable values for basic household composition tests.""" + return { + "age": { + "person": [5, 12, 40, 3, 25, 6, 8, 10, 45], + "household": [40, 25, 45], # Not used for age constraints + }, + "person_count": { + "person": [1, 1, 1, 1, 1, 1, 1, 1, 1], + "household": [3, 2, 4], # Sum per household + }, + "snap": { + "person": [100, 100, 100, 0, 0, 200, 200, 200, 200], + "household": [300, 0, 800], + }, + } + + +@pytest.fixture +def basic_sim(basic_entity_data, basic_variable_values): + """Mock simulation with basic household compositions.""" + return MockSimulation(basic_entity_data, basic_variable_values) + + +@pytest.fixture +def builder(): + """Create a minimal SparseMatrixBuilder (won't use DB for unit tests).""" + return SparseMatrixBuilder( + db_uri="sqlite:///:memory:", + time_period=2023, + cds_to_calibrate=["101"], + ) + + +# Tests for basic count target calculation +class TestCountTargetCalculation: + """Test _calculate_target_values_entity_aware for count targets.""" + + def test_person_count_with_age_constraints(self, builder, basic_sim): + """Test person_count correctly counts persons in age range per HH.""" + # Constraints: age >= 5 AND age < 18 + constraints = [ + {"variable": "age", "operation": ">=", "value": 5}, + {"variable": "age", "operation": "<", "value": 18}, + ] + + geo_mask = np.array([True, True, True]) # All households included + n_households = 3 + + result = builder._calculate_target_values_entity_aware( + basic_sim, + "person_count", + constraints, + geo_mask, + n_households, + ) + + # Expected: HH1 has 2 people (ages 5, 12), HH2 has 0, HH3 has 3 (6,8,10) + expected = np.array([2, 0, 3], dtype=np.float32) + np.testing.assert_array_equal(result, expected) + + def test_person_count_no_constraints(self, builder, basic_sim): + """Test person_count without constraints returns all persons per HH.""" + constraints = [] + geo_mask = np.array([True, True, True]) + n_households = 3 + + result = builder._calculate_target_values_entity_aware( + basic_sim, + "person_count", + constraints, + geo_mask, + n_households, + ) + + # Expected: HH1 has 3 people, HH2 has 2, HH3 has 4 + expected = np.array([3, 2, 4], dtype=np.float32) + np.testing.assert_array_equal(result, expected) + + def test_person_count_with_geo_mask(self, builder, basic_sim): + """Test person_count respects geographic mask.""" + constraints = [ + {"variable": 
"age", "operation": ">=", "value": 5}, + {"variable": "age", "operation": "<", "value": 18}, + ] + + # Only include households 1 and 3 + geo_mask = np.array([True, False, True]) + n_households = 3 + + result = builder._calculate_target_values_entity_aware( + basic_sim, + "person_count", + constraints, + geo_mask, + n_households, + ) + + # Expected: HH1=2, HH2=0 (masked out), HH3=3 + expected = np.array([2, 0, 3], dtype=np.float32) + np.testing.assert_array_equal(result, expected) + + def test_value_target_uses_sum(self, builder, basic_sim): + """Test that non-count targets sum values (existing behavior).""" + # SNAP is a value target, not a count target + constraints = [] + geo_mask = np.array([True, True, True]) + n_households = 3 + + result = builder._calculate_target_values_entity_aware( + basic_sim, + "snap", + constraints, + geo_mask, + n_households, + ) + + # Expected: Sum of snap values per household + expected = np.array([300, 0, 800], dtype=np.float32) + np.testing.assert_array_equal(result, expected) + + def test_household_count_no_constraints(self, builder, basic_sim): + """Test household_count returns 1 for each qualifying household.""" + constraints = [] + geo_mask = np.array([True, True, True]) + n_households = 3 + + result = builder._calculate_target_values_entity_aware( + basic_sim, + "household_count", + constraints, + geo_mask, + n_households, + ) + + # Expected: 1 for each household in geo_mask + expected = np.array([1, 1, 1], dtype=np.float32) + np.testing.assert_array_equal(result, expected) + + def test_household_count_with_geo_mask(self, builder, basic_sim): + """Test household_count respects geographic mask.""" + constraints = [] + geo_mask = np.array([True, False, True]) + n_households = 3 + + result = builder._calculate_target_values_entity_aware( + basic_sim, + "household_count", + constraints, + geo_mask, + n_households, + ) + + # Expected: 1 for HH1, 0 for HH2 (masked), 1 for HH3 + expected = np.array([1, 0, 1], dtype=np.float32) + np.testing.assert_array_equal(result, expected) + + +# Fixtures for complex entity relationship tests +@pytest.fixture +def complex_entity_data(): + """ + Create entity data with multiple tax units per household. 
+ + Household 1 (id=100): 4 people in 2 tax units + Tax unit 10: person 1 (age 30, filer), person 2 (age 28) + Tax unit 11: person 3 (age 65, filer), person 4 (age 62) + Household 2 (id=200): 2 people in 1 tax unit + Tax unit 20: person 5 (age 45, filer), person 6 (age 16) + """ + return { + "person_id": [1, 2, 3, 4, 5, 6], + "household_id": [100, 100, 100, 100, 200, 200], + "tax_unit_id": [10, 10, 11, 11, 20, 20], + "spm_unit_id": [1000, 1000, 1000, 1000, 2000, 2000], + } + + +@pytest.fixture +def complex_variable_values(): + """Variable values for complex entity relationship tests.""" + return { + "age": { + "person": [30, 28, 65, 62, 45, 16], + "household": [65, 45], + }, + "is_tax_unit_head": { + "person": [True, False, True, False, True, False], + "household": [2, 1], # count of heads per HH + }, + "tax_unit_count": { + "person": [1, 1, 1, 1, 1, 1], + "household": [2, 1], + }, + "person_count": { + "person": [1, 1, 1, 1, 1, 1], + "household": [4, 2], + }, + } + + +@pytest.fixture +def complex_sim(complex_entity_data, complex_variable_values): + """Mock simulation with complex entity relationships.""" + return MockSimulation(complex_entity_data, complex_variable_values) + + +# Tests for complex entity relationships +class TestCountTargetWithRealEntities: + """Test count targets with more complex entity relationships.""" + + def test_tax_unit_count_no_constraints(self, builder, complex_sim): + """Test tax_unit_count counts all tax units per household.""" + constraints = [] + geo_mask = np.array([True, True]) + n_households = 2 + + result = builder._calculate_target_values_entity_aware( + complex_sim, + "tax_unit_count", + constraints, + geo_mask, + n_households, + ) + + # Expected: HH1 has 2 tax units, HH2 has 1 + expected = np.array([2, 1], dtype=np.float32) + np.testing.assert_array_equal(result, expected) + + def test_tax_unit_count_with_age_constraint(self, builder, complex_sim): + """Test tax_unit_count with age constraint on members.""" + # Count tax units that have at least one person aged >= 65 + constraints = [ + {"variable": "age", "operation": ">=", "value": 65}, + ] + geo_mask = np.array([True, True]) + n_households = 2 + + result = builder._calculate_target_values_entity_aware( + complex_sim, + "tax_unit_count", + constraints, + geo_mask, + n_households, + ) + + # Expected: HH1 has 1 tax unit (TU 11) with person >=65, HH2 has 0 + expected = np.array([1, 0], dtype=np.float32) + np.testing.assert_array_equal(result, expected) + + def test_person_count_seniors(self, builder, complex_sim): + """Test person_count for seniors (age >= 65).""" + constraints = [ + {"variable": "age", "operation": ">=", "value": 65}, + ] + geo_mask = np.array([True, True]) + n_households = 2 + + result = builder._calculate_target_values_entity_aware( + complex_sim, + "person_count", + constraints, + geo_mask, + n_households, + ) + + # Expected: HH1 has 1 senior (age 65), HH2 has 0 + expected = np.array([1, 0], dtype=np.float32) + np.testing.assert_array_equal(result, expected) + + def test_person_count_children(self, builder, complex_sim): + """Test person_count for children (age < 18).""" + constraints = [ + {"variable": "age", "operation": "<", "value": 18}, + ] + geo_mask = np.array([True, True]) + n_households = 2 + + result = builder._calculate_target_values_entity_aware( + complex_sim, + "person_count", + constraints, + geo_mask, + n_households, + ) + + # Expected: HH1 has 0 children, HH2 has 1 (age 16) + expected = np.array([0, 1], dtype=np.float32) + np.testing.assert_array_equal(result, 
expected) From 77b863583eb743752f2b8a07ee22f051a86a0940 Mon Sep 17 00:00:00 2001 From: "baogorek@gmail.com" Date: Mon, 9 Feb 2026 14:59:44 -0500 Subject: [PATCH 4/4] Replace dedup machinery with period selection + uprating Remove ~200 lines of concept deduplication code (ConceptDuplicateWarning, _deduplicate_targets, build_concept_id, extract_constraints_from_row) and replace with SQL CTE that selects the best period per (stratum_id, variable) and uprating that adjusts non-current-period values using CPI or population factors from PolicyEngine parameters. Co-Authored-By: Claude Opus 4.6 --- .../calibration_utils.py | 75 --- .../fit_calibration_weights.py | 8 +- .../sparse_matrix_builder.py | 383 +++++---------- .../test_concept_deduplication.py | 439 ------------------ .../test_period_selection_and_uprating.py | 256 ++++++++++ 5 files changed, 385 insertions(+), 776 deletions(-) delete mode 100644 policyengine_us_data/tests/test_local_area_calibration/test_concept_deduplication.py create mode 100644 policyengine_us_data/tests/test_local_area_calibration/test_period_selection_and_uprating.py diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/calibration_utils.py b/policyengine_us_data/datasets/cps/local_area_calibration/calibration_utils.py index a99727b06..6a5c415ec 100644 --- a/policyengine_us_data/datasets/cps/local_area_calibration/calibration_utils.py +++ b/policyengine_us_data/datasets/cps/local_area_calibration/calibration_utils.py @@ -571,78 +571,3 @@ def calculate_spm_thresholds_for_cd( thresholds[i] = base * equiv_scale * geoadj return thresholds - - -def build_concept_id(variable: str, constraints: List[str]) -> str: - """ - Build normalized concept ID from variable + constraints. - - The concept ID uniquely identifies a calibration target "concept" - based on the variable being measured and its non-geographic constraints. - - Args: - variable: Target variable name (e.g., "person_count", "snap") - constraints: List of constraint strings (e.g., ["age>=5", "age<18"]) - - Returns: - Normalized concept ID string - - Examples: - >>> build_concept_id("person_count", ["age>=5", "age<18"]) - 'person_count_age_gte_5_age_lt_18' - >>> build_concept_id("snap", ["snap>0"]) - 'snap_snap_gt_0' - >>> build_concept_id("snap", []) - 'snap' - """ - if not constraints: - return variable - - # Normalize and sort constraints for consistent IDs - normalized = [] - for c in sorted(constraints): - c_norm = ( - c.replace(">=", "_gte_") - .replace("<=", "_lte_") - .replace(">", "_gt_") - .replace("<", "_lt_") - .replace("==", "_eq_") - .replace("=", "_eq_") - .replace(" ", "") - ) - normalized.append(c_norm) - - return f"{variable}_{'_'.join(normalized)}" - - -def extract_constraints_from_row( - row: pd.Series, exclude_geo: bool = True -) -> List[str]: - """ - Extract constraint list from a target row's constraint_info column. 
- - Args: - row: DataFrame row with 'constraint_info' column containing - pipe-separated constraints (e.g., "age>=5|age<18|state_fips=6") - exclude_geo: If True, filter out geographic constraints - (state_fips, congressional_district_geoid, tax_unit_is_filer) - - Returns: - List of constraint strings like ["age>=5", "age<18"] - """ - if "constraint_info" not in row or pd.isna(row["constraint_info"]): - return [] - - constraints = row["constraint_info"].split("|") - - if exclude_geo: - geo_vars = [ - "state_fips", - "congressional_district_geoid", - "tax_unit_is_filer", - ] - constraints = [ - c for c in constraints if not any(geo in c for geo in geo_vars) - ] - - return constraints diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/fit_calibration_weights.py b/policyengine_us_data/datasets/cps/local_area_calibration/fit_calibration_weights.py index 577344bd4..d01519a98 100644 --- a/policyengine_us_data/datasets/cps/local_area_calibration/fit_calibration_weights.py +++ b/policyengine_us_data/datasets/cps/local_area_calibration/fit_calibration_weights.py @@ -112,16 +112,12 @@ "state_income_tax", # Census STC state income tax collections ], }, - deduplicate=True, - dedup_mode="within_geography", ) -# Print concept and deduplication summaries -builder.print_concept_summary() -builder.print_dedup_summary() +builder.print_uprating_summary(targets_df) print(f"\nMatrix shape: {X_sparse.shape}") -print(f"Targets after deduplication: {len(targets_df)}") +print(f"Targets: {len(targets_df)}") # ============================================================================ # STEP 2: FILTER TO ACHIEVABLE TARGETS diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/sparse_matrix_builder.py b/policyengine_us_data/datasets/cps/local_area_calibration/sparse_matrix_builder.py index f22286f4e..5d2210c12 100644 --- a/policyengine_us_data/datasets/cps/local_area_calibration/sparse_matrix_builder.py +++ b/policyengine_us_data/datasets/cps/local_area_calibration/sparse_matrix_builder.py @@ -12,8 +12,7 @@ import numpy as np import pandas as pd from scipy import sparse -from dataclasses import dataclass -from sqlalchemy import create_engine +from sqlalchemy import create_engine, text logger = logging.getLogger(__name__) @@ -21,21 +20,9 @@ get_calculated_variables, apply_op, _get_geo_level, - build_concept_id, - extract_constraints_from_row, ) -@dataclass -class ConceptDuplicateWarning: - """Warning when multiple values exist for the same concept.""" - - concept_id: str - duplicates: List[dict] - selected: dict - reason: str - - class SparseMatrixBuilder: """Build sparse calibration matrices for geo-stacking.""" @@ -53,12 +40,6 @@ def __init__( self.dataset_path = dataset_path self._entity_rel_cache = None - # Populated after build_matrix() with deduplicate=True - self.concept_summary: Optional[pd.DataFrame] = None - self.dedup_warnings: List[ConceptDuplicateWarning] = [] - self.targets_before_dedup: Optional[pd.DataFrame] = None - self.targets_after_dedup: Optional[pd.DataFrame] = None - def _build_entity_relationship(self, sim) -> pd.DataFrame: """ Build entity relationship DataFrame mapping persons to all entity IDs. @@ -269,7 +250,11 @@ def _calculate_target_values_entity_aware( ) def _query_targets(self, target_filter: dict) -> pd.DataFrame: - """Query targets based on filter criteria using OR logic.""" + """Query targets, selecting best period per (stratum_id, variable). + + Best period: most recent period <= self.time_period, or closest + future period if none exists. 
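+
+        Example (illustrative): with time_period=2023 and stored periods
+        {2021, 2022, 2024} for a (stratum_id, variable) pair, 2022 is
+        selected; if only future periods {2024, 2025} exist, 2024 is
+        selected as the closest fallback.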
+ """ or_conditions = [] if "stratum_group_ids" in target_filter: @@ -289,22 +274,43 @@ def _query_targets(self, target_filter: dict) -> pd.DataFrame: or_conditions.append(f"t.stratum_id IN ({ids})") if not or_conditions: - # No filter criteria: fetch all targets where_clause = "1=1" else: where_clause = " OR ".join(f"({c})" for c in or_conditions) query = f""" - SELECT t.target_id, t.stratum_id, t.variable, t.value, t.period, - s.stratum_group_id - FROM targets t - JOIN strata s ON t.stratum_id = s.stratum_id - WHERE {where_clause} - ORDER BY t.target_id + WITH filtered_targets AS ( + SELECT t.target_id, t.stratum_id, t.variable, t.value, + t.period, s.stratum_group_id + FROM targets t + JOIN strata s ON t.stratum_id = s.stratum_id + WHERE {where_clause} + ), + best_periods AS ( + SELECT stratum_id, variable, + CASE + WHEN MAX(CASE WHEN period <= :time_period + THEN period END) IS NOT NULL + THEN MAX(CASE WHEN period <= :time_period + THEN period END) + ELSE MIN(period) + END as best_period + FROM filtered_targets + GROUP BY stratum_id, variable + ) + SELECT ft.* + FROM filtered_targets ft + JOIN best_periods bp + ON ft.stratum_id = bp.stratum_id + AND ft.variable = bp.variable + AND ft.period = bp.best_period + ORDER BY ft.target_id """ with self.engine.connect() as conn: - return pd.read_sql(query, conn) + return pd.read_sql( + query, conn, params={"time_period": self.time_period} + ) def _get_constraints(self, stratum_id: int) -> List[dict]: """Get all constraints for a stratum (including geographic).""" @@ -327,234 +333,103 @@ def _get_geographic_id(self, stratum_id: int) -> str: return c["value"] return "US" - def _get_constraint_info(self, stratum_id: int) -> str: - """Build pipe-separated constraint string for concept identification.""" - constraints = self._get_constraints(stratum_id) - parts = [] - for c in constraints: - op = "==" if c["operation"] == "=" else c["operation"] - parts.append(f"{c['variable']}{op}{c['value']}") - return "|".join(parts) if parts else None - - def _deduplicate_targets( - self, - targets_df: pd.DataFrame, - mode: str = "within_geography", - priority_column: str = "geo_priority", - ) -> pd.DataFrame: - """ - Deduplicate targets by concept before matrix building. - - Stores results in instance attributes for later inspection: - - self.concept_summary: DataFrame summarizing concepts - - self.dedup_warnings: List of ConceptDuplicateWarning - - self.targets_before_dedup: Original targets DataFrame - - self.targets_after_dedup: Deduplicated targets DataFrame - - Args: - targets_df: DataFrame with target rows including geographic_id - and constraint_info columns - mode: Deduplication mode ("within_geography" or - "hierarchical_fallback") - priority_column: Column to sort by when selecting among - duplicates. Lower values = higher priority. 
- - Returns: - Deduplicated DataFrame with reset index - """ - df = targets_df.copy() + def _calculate_uprating_factors(self, params) -> dict: + """Calculate CPI and population uprating factors for all periods.""" + factors = {} - # Add geo_priority if not present (CD=1, State=2, National=3) - if priority_column not in df.columns: - df["geo_priority"] = df["geographic_id"].apply( - lambda g: 3 if g == "US" else (1 if int(g) >= 100 else 2) - ) - priority_column = "geo_priority" - - # Build concept_id for each row - df["_concept_id"] = df.apply( - lambda row: build_concept_id( - row["variable"], - extract_constraints_from_row(row, exclude_geo=True), - ), - axis=1, - ) + query = "SELECT DISTINCT period FROM targets WHERE period IS NOT NULL ORDER BY period" + with self.engine.connect() as conn: + result = conn.execute(text(query)) + years_needed = [row[0] for row in result] - # Store concept summary - self.concept_summary = df.groupby("_concept_id").agg( - count=("_concept_id", "size"), - variable=("variable", "first"), - geos=("geographic_id", lambda x: list(x.unique())), + logger.info( + f"Calculating uprating factors for years " + f"{years_needed} to {self.time_period}" ) - # Store original for comparison - self.targets_before_dedup = df.copy() - - # Determine deduplication key based on mode - if mode == "within_geography": - if "geographic_id" not in df.columns: - raise ValueError( - "Mode 'within_geography' requires 'geographic_id' column" + for from_year in years_needed: + if from_year == self.time_period: + factors[(from_year, "cpi")] = 1.0 + factors[(from_year, "pop")] = 1.0 + continue + + try: + cpi_from = params.gov.bls.cpi.cpi_u(from_year) + cpi_to = params.gov.bls.cpi.cpi_u(self.time_period) + factors[(from_year, "cpi")] = float(cpi_to / cpi_from) + except Exception as e: + logger.warning( + f"Could not calculate CPI factor for " f"{from_year}: {e}" ) - dedupe_key = ["_concept_id", "geographic_id"] - elif mode == "hierarchical_fallback": - dedupe_key = ["_concept_id"] - else: - raise ValueError( - f"Unknown mode '{mode}'. 
Use 'within_geography' or " - "'hierarchical_fallback'" - ) + factors[(from_year, "cpi")] = 1.0 - # Find and process duplicates - warnings = [] - duplicate_mask = df.duplicated(subset=dedupe_key, keep=False) - duplicates_df = df[duplicate_mask] - - if len(duplicates_df) > 0: - for key_vals, group in duplicates_df.groupby(dedupe_key): - if len(group) <= 1: - continue - - dup_list = [] - for _, dup_row in group.iterrows(): - dup_list.append( - { - "geographic_id": dup_row.get("geographic_id", "?"), - "source": dup_row.get("source_name", "?"), - "period": dup_row.get("period", "?"), - "value": dup_row.get("value", "?"), - "stratum_id": dup_row.get("stratum_id", "?"), - } - ) - - sorted_group = group.sort_values(priority_column) - selected_row = sorted_group.iloc[0] - selected = { - "geographic_id": selected_row.get("geographic_id", "?"), - "source": selected_row.get("source_name", "?"), - "period": selected_row.get("period", "?"), - "value": selected_row.get("value", "?"), - } - - concept_id = ( - key_vals if isinstance(key_vals, str) else key_vals[0] + try: + pop_from = params.calibration.gov.census.populations.total( + from_year ) - warnings.append( - ConceptDuplicateWarning( - concept_id=concept_id, - duplicates=dup_list, - selected=selected, - reason=f"Selected by lowest {priority_column}", - ) + pop_to = params.calibration.gov.census.populations.total( + self.time_period ) + factors[(from_year, "pop")] = float(pop_to / pop_from) + except Exception as e: + logger.warning( + f"Could not calculate population factor for " + f"{from_year}: {e}" + ) + factors[(from_year, "pop")] = 1.0 - self.dedup_warnings = warnings - - # Deduplicate: sort by key + priority, keep first per key - sort_cols = ( - dedupe_key + [priority_column] - if priority_column in df.columns - else dedupe_key - ) - df_sorted = df.sort_values(sort_cols) - df_deduped = df_sorted.drop_duplicates(subset=dedupe_key, keep="first") - - # Clean up temporary column - df_deduped = df_deduped.drop(columns=["_concept_id"]) - - self.targets_after_dedup = df_deduped.copy() - - return df_deduped.reset_index(drop=True) + for (year, type_), factor in sorted(factors.items()): + if factor != 1.0: + logger.info( + f" {year} -> {self.time_period} " + f"({type_}): {factor:.4f}" + ) - def print_concept_summary(self) -> None: - """ - Print detailed concept summary from the last build_matrix() call. + return factors - Call this after build_matrix() to see what concepts were found. - """ - if self.concept_summary is None: - print("No concept summary available. 
Run build_matrix() first.") + def _get_uprating_info( + self, + variable: str, + period: int, + factors: dict, + ) -> Tuple[float, str]: + """Get uprating factor and type for a variable at a given period.""" + if period == self.time_period: + return 1.0, "none" + + count_indicators = [ + "count", + "person", + "people", + "households", + "tax_units", + ] + is_count = any(ind in variable.lower() for ind in count_indicators) + uprating_type = "pop" if is_count else "cpi" + + factor = factors.get((period, uprating_type), 1.0) + return factor, uprating_type + + def print_uprating_summary(self, targets_df: pd.DataFrame) -> None: + """Print summary of uprating applied to targets.""" + uprated = targets_df[targets_df["uprating_factor"] != 1.0] + if len(uprated) == 0: + print("No targets were uprated.") return print("\n" + "=" * 60) - print("CONCEPT SUMMARY") + print("UPRATING SUMMARY") print("=" * 60) + print(f"Uprated {len(uprated)} of {len(targets_df)} targets") - n_targets = ( - len(self.targets_before_dedup) - if self.targets_before_dedup is not None - else 0 - ) - print( - f"Found {len(self.concept_summary)} unique concepts " - f"from {n_targets} targets:\n" - ) - - for concept_id, row in self.concept_summary.iterrows(): - n_geos = len(row["geos"]) - print(f" {concept_id}") - print( - f" Variable: {row['variable']}, " - f"Targets: {row['count']}, Geographies: {n_geos}" - ) - - def print_dedup_summary(self) -> None: - """ - Print deduplication summary from the last build_matrix() call. - - Call this after build_matrix() to see what duplicates were removed. - """ - if self.targets_before_dedup is None: - print("No dedup summary available. Run build_matrix() first.") - return + period_counts = uprated["period"].value_counts().sort_index() + for period, count in period_counts.items(): + print(f" Period {period}: {count} targets") - print("\n" + "=" * 60) - print("DEDUPLICATION SUMMARY") - print("=" * 60) - - before = len(self.targets_before_dedup) - after = ( - len(self.targets_after_dedup) - if self.targets_after_dedup is not None - else 0 + factors = uprated["uprating_factor"] + print( + f" Factor range: [{factors.min():.4f}, " f"{factors.max():.4f}]" ) - removed = before - after - - print(f"Total targets queried: {before}") - print(f"Targets after deduplication: {after}") - print(f"Duplicates removed: {removed}") - - if self.dedup_warnings: - print(f"\nDuplicate groups resolved ({len(self.dedup_warnings)}):") - for w in self.dedup_warnings: - print(f"\n Concept: {w.concept_id}") - sel_val = w.selected["value"] - sel_val_str = ( - f"{sel_val:,.0f}" - if isinstance(sel_val, (int, float)) - else str(sel_val) - ) - print( - f" Selected: geo={w.selected['geographic_id']}, " - f"value={sel_val_str}" - ) - print(f" Removed ({len(w.duplicates) - 1}):") - for dup in w.duplicates: - if ( - dup["value"] != w.selected["value"] - or dup["geographic_id"] != w.selected["geographic_id"] - ): - dup_val = dup["value"] - dup_val_str = ( - f"{dup_val:,.0f}" - if isinstance(dup_val, (int, float)) - else str(dup_val) - ) - print( - f" - geo={dup['geographic_id']}, " - f"value={dup_val_str}, " - f"source={dup.get('source', '?')}" - ) def _create_state_sim(self, state: int, n_households: int): """Create a fresh simulation with state_fips set to given state.""" @@ -574,8 +449,6 @@ def build_matrix( self, sim, target_filter: dict, - deduplicate: bool = True, - dedup_mode: str = "within_geography", ) -> Tuple[pd.DataFrame, sparse.csr_matrix, Dict[str, List[str]]]: """ Build sparse calibration matrix. 
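+
+        Targets whose period differs from self.time_period are uprated
+        to self.time_period before the matrix is assembled (population
+        factors for count-like variables, CPI factors otherwise).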
@@ -587,18 +460,9 @@ def build_matrix( - {"stratum_group_ids": [4]} for SNAP targets - {"target_ids": [123, 456]} for specific targets - an empty dict {} will fetch all targets - deduplicate: If True, deduplicate targets by concept before - building the matrix (default True) - dedup_mode: Deduplication mode - "within_geography" (default) - removes duplicates with same concept AND geography, or - "hierarchical_fallback" keeps most specific geography - per concept Returns: Tuple of (targets_df, X_sparse, household_id_mapping) - - After calling this method, you can use print_concept_summary() and - print_dedup_summary() to see details about concepts and deduplication. """ household_ids = sim.calculate( "household_id", map_to="household" @@ -615,13 +479,20 @@ def build_matrix( targets_df["geographic_id"] = targets_df["stratum_id"].apply( self._get_geographic_id ) - targets_df["constraint_info"] = targets_df["stratum_id"].apply( - self._get_constraint_info - ) - # Deduplicate targets by concept before building matrix - if deduplicate: - targets_df = self._deduplicate_targets(targets_df, mode=dedup_mode) + # Uprate targets from their original period to self.time_period + params = sim.tax_benefit_system.parameters + uprating_factors = self._calculate_uprating_factors(params) + targets_df["original_value"] = targets_df["value"].copy() + targets_df["uprating_factor"] = targets_df.apply( + lambda row: self._get_uprating_info( + row["variable"], row["period"], uprating_factors + )[0], + axis=1, + ) + targets_df["value"] = ( + targets_df["original_value"] * targets_df["uprating_factor"] + ) n_targets = len(targets_df) diff --git a/policyengine_us_data/tests/test_local_area_calibration/test_concept_deduplication.py b/policyengine_us_data/tests/test_local_area_calibration/test_concept_deduplication.py deleted file mode 100644 index 57fe510e7..000000000 --- a/policyengine_us_data/tests/test_local_area_calibration/test_concept_deduplication.py +++ /dev/null @@ -1,439 +0,0 @@ -""" -Tests for concept ID building, constraint extraction, and deduplication. - -These tests verify that: -1. Concept IDs are built consistently from variable + non-geo constraints -2. Constraints are correctly extracted from DataFrame rows -3. 
Deduplication correctly identifies and removes duplicates via the builder -""" - -import unittest -import tempfile -import os -import pandas as pd -from sqlalchemy import create_engine, text - -from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import ( - build_concept_id, - extract_constraints_from_row, -) -from policyengine_us_data.datasets.cps.local_area_calibration.sparse_matrix_builder import ( - SparseMatrixBuilder, -) - - -class TestBuildConceptId(unittest.TestCase): - """Test concept ID building from variable + constraints.""" - - def test_variable_only(self): - """Test concept ID with no constraints.""" - result = build_concept_id("snap", []) - self.assertEqual(result, "snap") - - def test_single_constraint(self): - """Test concept ID with single constraint.""" - result = build_concept_id("snap", ["snap>0"]) - self.assertEqual(result, "snap_snap_gt_0") - - def test_multiple_constraints_sorted(self): - """Test that constraints are sorted for consistency.""" - # Order shouldn't matter - result should be the same - result1 = build_concept_id("person_count", ["age>=5", "age<18"]) - result2 = build_concept_id("person_count", ["age<18", "age>=5"]) - self.assertEqual(result1, result2) - self.assertEqual(result1, "person_count_age_lt_18_age_gte_5") - - def test_operator_normalization(self): - """Test that operators are normalized correctly.""" - self.assertIn("_gte_", build_concept_id("x", ["a>=1"])) - self.assertIn("_lte_", build_concept_id("x", ["a<=1"])) - self.assertIn("_gt_", build_concept_id("x", ["a>1"])) - self.assertIn("_lt_", build_concept_id("x", ["a<1"])) - self.assertIn("_eq_", build_concept_id("x", ["a==1"])) - self.assertIn("_eq_", build_concept_id("x", ["a=1"])) - - def test_spaces_removed(self): - """Test that spaces are removed from constraints.""" - result = build_concept_id("x", ["age >= 5"]) - self.assertNotIn(" ", result) - - -class TestExtractConstraints(unittest.TestCase): - """Test constraint extraction from DataFrame rows.""" - - def test_no_constraint_info(self): - """Test row without constraint_info column.""" - row = pd.Series({"variable": "snap", "value": 1000}) - result = extract_constraints_from_row(row) - self.assertEqual(result, []) - - def test_null_constraint_info(self): - """Test row with null constraint_info.""" - row = pd.Series( - {"variable": "snap", "constraint_info": None, "value": 1000} - ) - result = extract_constraints_from_row(row) - self.assertEqual(result, []) - - def test_single_constraint(self): - """Test row with single constraint.""" - row = pd.Series( - {"variable": "snap", "constraint_info": "snap>0", "value": 1000} - ) - result = extract_constraints_from_row(row) - self.assertEqual(result, ["snap>0"]) - - def test_multiple_constraints(self): - """Test row with pipe-separated constraints.""" - row = pd.Series( - { - "variable": "person_count", - "constraint_info": "age>=5|age<18", - "value": 1000, - } - ) - result = extract_constraints_from_row(row) - self.assertEqual(result, ["age>=5", "age<18"]) - - def test_exclude_geo_constraints(self): - """Test that geographic constraints are excluded by default.""" - row = pd.Series( - { - "variable": "person_count", - "constraint_info": "age>=5|state_fips=6|age<18", - "value": 1000, - } - ) - result = extract_constraints_from_row(row, exclude_geo=True) - self.assertEqual(result, ["age>=5", "age<18"]) - self.assertNotIn("state_fips=6", result) - - def test_include_geo_constraints(self): - """Test that geographic constraints can be included.""" - row = pd.Series( - { 
- "variable": "person_count", - "constraint_info": "age>=5|state_fips=6", - "value": 1000, - } - ) - result = extract_constraints_from_row(row, exclude_geo=False) - self.assertIn("state_fips=6", result) - - def test_exclude_cd_geoid(self): - """Test that CD geoid constraints are excluded.""" - row = pd.Series( - { - "variable": "snap", - "constraint_info": "snap>0|congressional_district_geoid=601", - "value": 1000, - } - ) - result = extract_constraints_from_row(row, exclude_geo=True) - self.assertEqual(result, ["snap>0"]) - - def test_exclude_filer_constraint(self): - """Test that tax_unit_is_filer constraint is excluded.""" - row = pd.Series( - { - "variable": "income_tax", - "constraint_info": "tax_unit_is_filer=True|income>0", - "value": 1000, - } - ) - result = extract_constraints_from_row(row, exclude_geo=True) - self.assertEqual(result, ["income>0"]) - - -class TestBuilderDeduplication(unittest.TestCase): - """Test deduplication logic through SparseMatrixBuilder.""" - - @classmethod - def setUpClass(cls): - """Create a temporary database with test data.""" - cls.temp_db = tempfile.NamedTemporaryFile(suffix=".db", delete=False) - cls.db_path = cls.temp_db.name - cls.temp_db.close() - - cls.db_uri = f"sqlite:///{cls.db_path}" - engine = create_engine(cls.db_uri) - - # Create schema - with engine.connect() as conn: - conn.execute(text(""" - CREATE TABLE stratum_groups ( - stratum_group_id INTEGER PRIMARY KEY, - name TEXT - ) - """)) - conn.execute(text(""" - CREATE TABLE strata ( - stratum_id INTEGER PRIMARY KEY, - stratum_group_id INTEGER - ) - """)) - conn.execute(text(""" - CREATE TABLE stratum_constraints ( - constraint_id INTEGER PRIMARY KEY, - stratum_id INTEGER, - constraint_variable TEXT, - operation TEXT, - value TEXT - ) - """)) - conn.execute(text(""" - CREATE TABLE targets ( - target_id INTEGER PRIMARY KEY, - stratum_id INTEGER, - variable TEXT, - value REAL, - period INTEGER - ) - """)) - conn.commit() - - @classmethod - def tearDownClass(cls): - """Remove temporary database.""" - os.unlink(cls.db_path) - - def setUp(self): - """Clear tables before each test.""" - engine = create_engine(self.db_uri) - with engine.connect() as conn: - conn.execute(text("DELETE FROM targets")) - conn.execute(text("DELETE FROM stratum_constraints")) - conn.execute(text("DELETE FROM strata")) - conn.execute(text("DELETE FROM stratum_groups")) - conn.commit() - - def _insert_test_data(self, strata, constraints, targets): - """Helper to insert test data into database.""" - engine = create_engine(self.db_uri) - with engine.connect() as conn: - # Insert stratum groups - conn.execute( - text("INSERT OR IGNORE INTO stratum_groups VALUES (1, 'test')") - ) - - # Insert strata - for stratum_id, group_id in strata: - conn.execute( - text("INSERT INTO strata VALUES (:sid, :gid)"), - {"sid": stratum_id, "gid": group_id}, - ) - - # Insert constraints - for i, (stratum_id, var, op, val) in enumerate(constraints): - conn.execute( - text(""" - INSERT INTO stratum_constraints - VALUES (:cid, :sid, :var, :op, :val) - """), - { - "cid": i + 1, - "sid": stratum_id, - "var": var, - "op": op, - "val": val, - }, - ) - - # Insert targets - for i, (stratum_id, variable, value, period) in enumerate(targets): - conn.execute( - text(""" - INSERT INTO targets - VALUES (:tid, :sid, :var, :val, :period) - """), - { - "tid": i + 1, - "sid": stratum_id, - "var": variable, - "val": value, - "period": period, - }, - ) - - conn.commit() - - def test_no_duplicates_preserved(self): - """Test that targets with different concepts 
are all preserved.""" - # Two different variables for the same CD - should NOT deduplicate - self._insert_test_data( - strata=[(1, 1), (2, 1)], - constraints=[ - (1, "congressional_district_geoid", "=", "601"), - (2, "congressional_district_geoid", "=", "601"), - ], - targets=[ - (1, "snap", 1000, 2023), - (2, "medicaid", 2000, 2023), - ], - ) - - builder = SparseMatrixBuilder( - db_uri=self.db_uri, - time_period=2023, - cds_to_calibrate=["601"], - ) - - # Call _deduplicate_targets directly with prepared DataFrame - targets_df = builder._query_targets({"stratum_group_ids": [1]}) - targets_df["geographic_id"] = targets_df["stratum_id"].apply( - builder._get_geographic_id - ) - targets_df["constraint_info"] = targets_df["stratum_id"].apply( - builder._get_constraint_info - ) - - result = builder._deduplicate_targets(targets_df) - - self.assertEqual(len(result), 2) - self.assertEqual(len(builder.dedup_warnings), 0) - - def test_duplicate_same_geo_deduplicated(self): - """Test that same concept at same geography is deduplicated.""" - # Same variable, same CD, different periods - should deduplicate - self._insert_test_data( - strata=[(1, 1), (2, 1)], - constraints=[ - (1, "congressional_district_geoid", "=", "601"), - (2, "congressional_district_geoid", "=", "601"), - ], - targets=[ - (1, "snap", 1000, 2023), - (2, "snap", 1100, 2022), # Same concept, same geo - ], - ) - - builder = SparseMatrixBuilder( - db_uri=self.db_uri, - time_period=2023, - cds_to_calibrate=["601"], - ) - - targets_df = builder._query_targets({"stratum_group_ids": [1]}) - targets_df["geographic_id"] = targets_df["stratum_id"].apply( - builder._get_geographic_id - ) - targets_df["constraint_info"] = targets_df["stratum_id"].apply( - builder._get_constraint_info - ) - - result = builder._deduplicate_targets(targets_df) - - self.assertEqual(len(result), 1) - self.assertEqual(len(builder.dedup_warnings), 1) - - def test_same_concept_different_geos_preserved(self): - """Test that same concept at different geos is NOT deduplicated.""" - # Same variable, different CDs - should NOT deduplicate - self._insert_test_data( - strata=[(1, 1), (2, 1)], - constraints=[ - (1, "congressional_district_geoid", "=", "601"), - (2, "congressional_district_geoid", "=", "602"), - ], - targets=[ - (1, "snap", 1000, 2023), - (2, "snap", 1100, 2023), - ], - ) - - builder = SparseMatrixBuilder( - db_uri=self.db_uri, - time_period=2023, - cds_to_calibrate=["601", "602"], - ) - - targets_df = builder._query_targets({"stratum_group_ids": [1]}) - targets_df["geographic_id"] = targets_df["stratum_id"].apply( - builder._get_geographic_id - ) - targets_df["constraint_info"] = targets_df["stratum_id"].apply( - builder._get_constraint_info - ) - - result = builder._deduplicate_targets(targets_df) - - self.assertEqual(len(result), 2) # Both kept - self.assertEqual(len(builder.dedup_warnings), 0) - - def test_different_constraints_different_concepts(self): - """Test that different constraints create different concepts.""" - # Same variable but different age constraints - different concepts - self._insert_test_data( - strata=[(1, 1), (2, 1)], - constraints=[ - (1, "congressional_district_geoid", "=", "601"), - (1, "age", ">=", "5"), - (1, "age", "<", "18"), - (2, "congressional_district_geoid", "=", "601"), - (2, "age", ">=", "18"), - (2, "age", "<", "65"), - ], - targets=[ - (1, "person_count", 1000, 2023), - (2, "person_count", 2000, 2023), - ], - ) - - builder = SparseMatrixBuilder( - db_uri=self.db_uri, - time_period=2023, - cds_to_calibrate=["601"], - ) - - 
targets_df = builder._query_targets({"stratum_group_ids": [1]}) - targets_df["geographic_id"] = targets_df["stratum_id"].apply( - builder._get_geographic_id - ) - targets_df["constraint_info"] = targets_df["stratum_id"].apply( - builder._get_constraint_info - ) - - result = builder._deduplicate_targets(targets_df) - - self.assertEqual(len(result), 2) # Different concepts - self.assertEqual(len(builder.dedup_warnings), 0) - - def test_hierarchical_fallback_keeps_most_specific(self): - """Test hierarchical fallback mode keeps CD over state over national.""" - # Same concept at CD, state, and national levels - self._insert_test_data( - strata=[(1, 1), (2, 1), (3, 1)], - constraints=[ - (1, "congressional_district_geoid", "=", "601"), - (2, "state_fips", "=", "6"), - # stratum 3 has no geo constraint = national - ], - targets=[ - (1, "snap", 1200000, 2023), # CD level - (2, "snap", 15000000, 2023), # State level - (3, "snap", 110000000000, 2023), # National level - ], - ) - - builder = SparseMatrixBuilder( - db_uri=self.db_uri, - time_period=2023, - cds_to_calibrate=["601"], - ) - - targets_df = builder._query_targets({"stratum_group_ids": [1]}) - targets_df["geographic_id"] = targets_df["stratum_id"].apply( - builder._get_geographic_id - ) - targets_df["constraint_info"] = targets_df["stratum_id"].apply( - builder._get_constraint_info - ) - - result = builder._deduplicate_targets( - targets_df, mode="hierarchical_fallback" - ) - - self.assertEqual(len(result), 1) - # CD level should be kept (geo_priority=1) - self.assertEqual(result.iloc[0]["geographic_id"], "601") - self.assertEqual(result.iloc[0]["value"], 1200000) diff --git a/policyengine_us_data/tests/test_local_area_calibration/test_period_selection_and_uprating.py b/policyengine_us_data/tests/test_local_area_calibration/test_period_selection_and_uprating.py new file mode 100644 index 000000000..639dc7369 --- /dev/null +++ b/policyengine_us_data/tests/test_local_area_calibration/test_period_selection_and_uprating.py @@ -0,0 +1,256 @@ +""" +Tests for best-period selection and uprating in SparseMatrixBuilder. 
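+
+These tests verify that:
+1. _query_targets picks the most recent period <= time_period, falling
+   back to the closest future period when no earlier period exists
+2. _get_uprating_info applies population factors to count-like variables
+   and CPI factors to dollar-valued variables, defaulting to 1.0 when a
+   factor is missing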
+""" + +import unittest +import tempfile +import os +import pandas as pd +from sqlalchemy import create_engine, text + +from policyengine_us_data.datasets.cps.local_area_calibration.sparse_matrix_builder import ( + SparseMatrixBuilder, +) + + +class TestPeriodSelectionAndUprating(unittest.TestCase): + """Test best-period SQL CTE and uprating logic.""" + + @classmethod + def setUpClass(cls): + cls.temp_db = tempfile.NamedTemporaryFile(suffix=".db", delete=False) + cls.db_path = cls.temp_db.name + cls.temp_db.close() + + cls.db_uri = f"sqlite:///{cls.db_path}" + engine = create_engine(cls.db_uri) + + with engine.connect() as conn: + conn.execute( + text( + "CREATE TABLE stratum_groups (" + "stratum_group_id INTEGER PRIMARY KEY, " + "name TEXT)" + ) + ) + conn.execute( + text( + "CREATE TABLE strata (" + "stratum_id INTEGER PRIMARY KEY, " + "stratum_group_id INTEGER)" + ) + ) + conn.execute( + text( + "CREATE TABLE stratum_constraints (" + "constraint_id INTEGER PRIMARY KEY, " + "stratum_id INTEGER, " + "constraint_variable TEXT, " + "operation TEXT, " + "value TEXT)" + ) + ) + conn.execute( + text( + "CREATE TABLE targets (" + "target_id INTEGER PRIMARY KEY, " + "stratum_id INTEGER, " + "variable TEXT, " + "value REAL, " + "period INTEGER)" + ) + ) + conn.commit() + + @classmethod + def tearDownClass(cls): + os.unlink(cls.db_path) + + def setUp(self): + engine = create_engine(self.db_uri) + with engine.connect() as conn: + conn.execute(text("DELETE FROM targets")) + conn.execute(text("DELETE FROM stratum_constraints")) + conn.execute(text("DELETE FROM strata")) + conn.execute(text("DELETE FROM stratum_groups")) + conn.commit() + + def _insert_test_data(self, strata, constraints, targets): + engine = create_engine(self.db_uri) + with engine.connect() as conn: + conn.execute( + text( + "INSERT OR IGNORE INTO stratum_groups " + "VALUES (1, 'test')" + ) + ) + for stratum_id, group_id in strata: + conn.execute( + text("INSERT INTO strata VALUES (:sid, :gid)"), + {"sid": stratum_id, "gid": group_id}, + ) + for i, (stratum_id, var, op, val) in enumerate(constraints): + conn.execute( + text( + "INSERT INTO stratum_constraints " + "VALUES (:cid, :sid, :var, :op, :val)" + ), + { + "cid": i + 1, + "sid": stratum_id, + "var": var, + "op": op, + "val": val, + }, + ) + for i, ( + stratum_id, + variable, + value, + period, + ) in enumerate(targets): + conn.execute( + text( + "INSERT INTO targets " + "VALUES (:tid, :sid, :var, :val, :period)" + ), + { + "tid": i + 1, + "sid": stratum_id, + "var": variable, + "val": value, + "period": period, + }, + ) + conn.commit() + + def _make_builder(self, time_period=2024): + return SparseMatrixBuilder( + db_uri=self.db_uri, + time_period=time_period, + cds_to_calibrate=["601"], + ) + + # ---- Period selection tests ---- + + def test_best_period_prefers_past(self): + """Targets at 2022 and 2026 -> picks 2022 for time_period=2024.""" + self._insert_test_data( + strata=[(1, 1)], + constraints=[ + (1, "congressional_district_geoid", "=", "601"), + ], + targets=[ + (1, "snap", 1000, 2022), + (1, "snap", 2000, 2026), + ], + ) + builder = self._make_builder(time_period=2024) + df = builder._query_targets({"stratum_group_ids": [1]}) + self.assertEqual(len(df), 1) + self.assertEqual(df.iloc[0]["period"], 2022) + self.assertEqual(df.iloc[0]["value"], 1000) + + def test_best_period_uses_future_when_no_past(self): + """Target only at 2026 -> picks 2026 for time_period=2024.""" + self._insert_test_data( + strata=[(1, 1)], + constraints=[ + (1, "congressional_district_geoid", "=", 
"601"), + ], + targets=[ + (1, "snap", 5000, 2026), + ], + ) + builder = self._make_builder(time_period=2024) + df = builder._query_targets({"stratum_group_ids": [1]}) + self.assertEqual(len(df), 1) + self.assertEqual(df.iloc[0]["period"], 2026) + + def test_best_period_exact_match(self): + """Targets at 2022, 2024, 2026 -> picks 2024 exactly.""" + self._insert_test_data( + strata=[(1, 1)], + constraints=[ + (1, "congressional_district_geoid", "=", "601"), + ], + targets=[ + (1, "snap", 1000, 2022), + (1, "snap", 1500, 2024), + (1, "snap", 2000, 2026), + ], + ) + builder = self._make_builder(time_period=2024) + df = builder._query_targets({"stratum_group_ids": [1]}) + self.assertEqual(len(df), 1) + self.assertEqual(df.iloc[0]["period"], 2024) + self.assertEqual(df.iloc[0]["value"], 1500) + + def test_independent_per_stratum_and_variable(self): + """Different strata/variables select independently.""" + self._insert_test_data( + strata=[(1, 1), (2, 1)], + constraints=[ + (1, "congressional_district_geoid", "=", "601"), + (2, "congressional_district_geoid", "=", "601"), + ], + targets=[ + (1, "snap", 1000, 2024), + (1, "snap", 800, 2022), + (2, "person_count", 500, 2022), + (2, "person_count", 600, 2026), + ], + ) + builder = self._make_builder(time_period=2024) + df = builder._query_targets({"stratum_group_ids": [1]}) + self.assertEqual(len(df), 2) + snap_row = df[df["variable"] == "snap"].iloc[0] + self.assertEqual(snap_row["period"], 2024) + count_row = df[df["variable"] == "person_count"].iloc[0] + self.assertEqual(count_row["period"], 2022) + + # ---- Uprating info tests ---- + + def test_cpi_uprating_for_dollar_vars(self): + builder = self._make_builder(time_period=2024) + factors = {(2022, "cpi"): 1.06, (2022, "pop"): 1.01} + factor, type_ = builder._get_uprating_info("snap", 2022, factors) + self.assertAlmostEqual(factor, 1.06) + self.assertEqual(type_, "cpi") + + def test_pop_uprating_for_count_vars(self): + builder = self._make_builder(time_period=2024) + factors = {(2022, "cpi"): 1.06, (2022, "pop"): 1.01} + factor, type_ = builder._get_uprating_info( + "person_count", 2022, factors + ) + self.assertAlmostEqual(factor, 1.01) + self.assertEqual(type_, "pop") + + def test_no_uprating_for_current_period(self): + builder = self._make_builder(time_period=2024) + factors = {(2024, "cpi"): 1.0, (2024, "pop"): 1.0} + factor, type_ = builder._get_uprating_info("snap", 2024, factors) + self.assertAlmostEqual(factor, 1.0) + self.assertEqual(type_, "none") + + def test_pop_uprating_households_variable(self): + builder = self._make_builder(time_period=2024) + factors = {(2022, "cpi"): 1.06, (2022, "pop"): 1.02} + factor, type_ = builder._get_uprating_info("households", 2022, factors) + self.assertAlmostEqual(factor, 1.02) + self.assertEqual(type_, "pop") + + def test_pop_uprating_tax_units_variable(self): + builder = self._make_builder(time_period=2024) + factors = {(2022, "cpi"): 1.06, (2022, "pop"): 1.02} + factor, type_ = builder._get_uprating_info("tax_units", 2022, factors) + self.assertAlmostEqual(factor, 1.02) + self.assertEqual(type_, "pop") + + def test_missing_factor_defaults_to_1(self): + builder = self._make_builder(time_period=2024) + factors = {} + factor, type_ = builder._get_uprating_info("snap", 2020, factors) + self.assertAlmostEqual(factor, 1.0) + self.assertEqual(type_, "cpi")