diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..bfd1545c 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,5 @@ +- bump: minor + changes: + added: + - Deduplication logic in SparseMatrixBuilder (option to remove duplicate targets or select most specific geographic level). + - Entity aware target calculations for correct entity counts. \ No newline at end of file diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/fit_calibration_weights.py b/policyengine_us_data/datasets/cps/local_area_calibration/fit_calibration_weights.py index b22b8eb4..d01519a9 100644 --- a/policyengine_us_data/datasets/cps/local_area_calibration/fit_calibration_weights.py +++ b/policyengine_us_data/datasets/cps/local_area_calibration/fit_calibration_weights.py @@ -114,9 +114,14 @@ }, ) -print(f"Matrix shape: {X_sparse.shape}") +builder.print_uprating_summary(targets_df) + +print(f"\nMatrix shape: {X_sparse.shape}") print(f"Targets: {len(targets_df)}") +# ============================================================================ +# STEP 2: FILTER TO ACHIEVABLE TARGETS +# ============================================================================ # Filter to achievable targets (rows with non-zero data) row_sums = np.array(X_sparse.sum(axis=1)).flatten() achievable_mask = row_sums > 0 @@ -129,7 +134,7 @@ targets_df = targets_df[achievable_mask].reset_index(drop=True) X_sparse = X_sparse[achievable_mask, :] -print(f"Filtered matrix shape: {X_sparse.shape}") +print(f"Final matrix shape: {X_sparse.shape}") # Extract target vector and names targets = targets_df["value"].values @@ -139,14 +144,14 @@ ] # ============================================================================ -# STEP 2: INITIALIZE WEIGHTS +# STEP 3: INITIALIZE WEIGHTS # ============================================================================ initial_weights = np.ones(X_sparse.shape[1]) * 100 print(f"\nInitial weights shape: {initial_weights.shape}") print(f"Initial weights sum: {initial_weights.sum():,.0f}") # ============================================================================ -# STEP 3: CREATE MODEL +# STEP 4: CREATE MODEL # ============================================================================ print("\nCreating SparseCalibrationWeights model...") model = SparseCalibrationWeights( @@ -162,7 +167,7 @@ ) # ============================================================================ -# STEP 4: TRAIN IN CHUNKS +# STEP 5: TRAIN IN CHUNKS # ============================================================================ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") calibration_log = pd.DataFrame() @@ -205,7 +210,7 @@ calibration_log = pd.concat([calibration_log, chunk_df], ignore_index=True) # ============================================================================ -# STEP 5: EXTRACT AND SAVE WEIGHTS +# STEP 6: EXTRACT AND SAVE WEIGHTS # ============================================================================ with torch.no_grad(): w = model.get_weights(deterministic=True).cpu().numpy() @@ -225,7 +230,7 @@ print(f"LOG_PATH:{log_path}") # ============================================================================ -# STEP 6: VERIFY PREDICTIONS +# STEP 7: VERIFY PREDICTIONS # ============================================================================ print("\n" + "=" * 60) print("PREDICTION VERIFICATION") diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/sparse_matrix_builder.py b/policyengine_us_data/datasets/cps/local_area_calibration/sparse_matrix_builder.py index 5ffe5474..5d2210c1 100644 --- a/policyengine_us_data/datasets/cps/local_area_calibration/sparse_matrix_builder.py +++ b/policyengine_us_data/datasets/cps/local_area_calibration/sparse_matrix_builder.py @@ -138,8 +138,123 @@ def _evaluate_constraints_entity_aware( return household_mask + def _calculate_target_values_entity_aware( + self, + state_sim, + target_variable: str, + non_geo_constraints: List[dict], + geo_mask: np.ndarray, + n_households: int, + ) -> np.ndarray: + """ + Calculate target values at household level, handling count targets. + + For count targets (*_count): Count entities per household satisfying + constraints + For value targets: Sum values at household level (existing behavior) + + Args: + state_sim: Microsimulation with state_fips set + target_variable: The target variable name (e.g., "snap", + "person_count") + non_geo_constraints: List of constraint dicts (geographic + constraints should be pre-filtered) + geo_mask: Boolean mask array for geographic filtering (household + level) + n_households: Number of households + + Returns: + Float array of target values at household level + """ + is_count_target = target_variable.endswith("_count") + + if not is_count_target: + # Value target: use existing entity-aware constraint evaluation + entity_mask = self._evaluate_constraints_entity_aware( + state_sim, non_geo_constraints, n_households + ) + mask = geo_mask & entity_mask + + target_values = state_sim.calculate( + target_variable, map_to="household" + ).values + return (target_values * mask).astype(np.float32) + + # Count target: need to count entities satisfying constraints + entity_rel = self._build_entity_relationship(state_sim) + n_persons = len(entity_rel) + + # Evaluate constraints at person level (don't aggregate to HH yet) + person_mask = np.ones(n_persons, dtype=bool) + for c in non_geo_constraints: + constraint_values = state_sim.calculate( + c["variable"], map_to="person" + ).values + person_mask &= apply_op( + constraint_values, c["operation"], c["value"] + ) + + # Get target entity from variable definition + target_entity = state_sim.tax_benefit_system.variables[ + target_variable + ].entity.key + + household_ids = state_sim.calculate( + "household_id", map_to="household" + ).values + geo_mask_map = dict(zip(household_ids, geo_mask)) + + if target_entity == "household": + # household_count: 1 per qualifying household + if non_geo_constraints: + entity_mask = self._evaluate_constraints_entity_aware( + state_sim, non_geo_constraints, n_households + ) + return (geo_mask & entity_mask).astype(np.float32) + return geo_mask.astype(np.float32) + + if target_entity == "person": + # Count persons satisfying constraints per household + entity_rel["satisfies"] = person_mask + entity_rel["geo_ok"] = entity_rel["household_id"].map(geo_mask_map) + filtered = entity_rel[ + entity_rel["satisfies"] & entity_rel["geo_ok"] + ] + counts = filtered.groupby("household_id")["person_id"].nunique() + else: + # For tax_unit, spm_unit: aggregate person mask to entity, then + # count + entity_id_col = f"{target_entity}_id" + entity_rel["satisfies"] = person_mask + entity_satisfies = entity_rel.groupby(entity_id_col)[ + "satisfies" + ].any() + + entity_rel_unique = entity_rel[ + ["household_id", entity_id_col] + ].drop_duplicates() + entity_rel_unique["entity_ok"] = entity_rel_unique[ + entity_id_col + ].map(entity_satisfies) + entity_rel_unique["geo_ok"] = entity_rel_unique[ + "household_id" + ].map(geo_mask_map) + filtered = entity_rel_unique[ + entity_rel_unique["entity_ok"] & entity_rel_unique["geo_ok"] + ] + counts = filtered.groupby("household_id")[entity_id_col].nunique() + + # Build result aligned with household order + return np.array( + [counts.get(hh_id, 0) for hh_id in household_ids], dtype=np.float32 + ) + def _query_targets(self, target_filter: dict) -> pd.DataFrame: - """Query targets based on filter criteria using OR logic.""" + """Query targets, selecting best period per (stratum_id, variable). + + Best period: most recent period <= self.time_period, or closest + future period if none exists. + """ or_conditions = [] if "stratum_group_ids" in target_filter: @@ -159,23 +274,43 @@ def _query_targets(self, target_filter: dict) -> pd.DataFrame: or_conditions.append(f"t.stratum_id IN ({ids})") if not or_conditions: - raise ValueError( - "target_filter must specify at least one filter criterion" - ) - - where_clause = " OR ".join(f"({c})" for c in or_conditions) + where_clause = "1=1" + else: + where_clause = " OR ".join(f"({c})" for c in or_conditions) query = f""" - SELECT t.target_id, t.stratum_id, t.variable, t.value, t.period, - s.stratum_group_id - FROM targets t - JOIN strata s ON t.stratum_id = s.stratum_id - WHERE {where_clause} - ORDER BY t.target_id + WITH filtered_targets AS ( + SELECT t.target_id, t.stratum_id, t.variable, t.value, + t.period, s.stratum_group_id + FROM targets t + JOIN strata s ON t.stratum_id = s.stratum_id + WHERE {where_clause} + ), + best_periods AS ( + SELECT stratum_id, variable, + CASE + WHEN MAX(CASE WHEN period <= :time_period + THEN period END) IS NOT NULL + THEN MAX(CASE WHEN period <= :time_period + THEN period END) + ELSE MIN(period) + END as best_period + FROM filtered_targets + GROUP BY stratum_id, variable + ) + SELECT ft.* + FROM filtered_targets ft + JOIN best_periods bp + ON ft.stratum_id = bp.stratum_id + AND ft.variable = bp.variable + AND ft.period = bp.best_period + ORDER BY ft.target_id """ with self.engine.connect() as conn: - return pd.read_sql(query, conn) + return pd.read_sql( + query, conn, params={"time_period": self.time_period} + ) def _get_constraints(self, stratum_id: int) -> List[dict]: """Get all constraints for a stratum (including geographic).""" @@ -198,6 +333,104 @@ def _get_geographic_id(self, stratum_id: int) -> str: return c["value"] return "US" + def _calculate_uprating_factors(self, params) -> dict: + """Calculate CPI and population uprating factors for all periods.""" + factors = {} + + query = "SELECT DISTINCT period FROM targets WHERE period IS NOT NULL ORDER BY period" + with self.engine.connect() as conn: + result = conn.execute(text(query)) + years_needed = [row[0] for row in result] + + logger.info( + f"Calculating uprating factors for years " + f"{years_needed} to {self.time_period}" + ) + + for from_year in years_needed: + if from_year == self.time_period: + factors[(from_year, "cpi")] = 1.0 + factors[(from_year, "pop")] = 1.0 + continue + + try: + cpi_from = params.gov.bls.cpi.cpi_u(from_year) + cpi_to = params.gov.bls.cpi.cpi_u(self.time_period) + factors[(from_year, "cpi")] = float(cpi_to / cpi_from) + except Exception as e: + logger.warning( + f"Could not calculate CPI factor for " f"{from_year}: {e}" + ) + factors[(from_year, "cpi")] = 1.0 + + try: + pop_from = params.calibration.gov.census.populations.total( + from_year + ) + pop_to = params.calibration.gov.census.populations.total( + self.time_period + ) + factors[(from_year, "pop")] = float(pop_to / pop_from) + except Exception as e: + logger.warning( + f"Could not calculate population factor for " + f"{from_year}: {e}" + ) + factors[(from_year, "pop")] = 1.0 + + for (year, type_), factor in sorted(factors.items()): + if factor != 1.0: + logger.info( + f" {year} -> {self.time_period} " + f"({type_}): {factor:.4f}" + ) + + return factors + + def _get_uprating_info( + self, + variable: str, + period: int, + factors: dict, + ) -> Tuple[float, str]: + """Get uprating factor and type for a variable at a given period.""" + if period == self.time_period: + return 1.0, "none" + + count_indicators = [ + "count", + "person", + "people", + "households", + "tax_units", + ] + is_count = any(ind in variable.lower() for ind in count_indicators) + uprating_type = "pop" if is_count else "cpi" + + factor = factors.get((period, uprating_type), 1.0) + return factor, uprating_type + + def print_uprating_summary(self, targets_df: pd.DataFrame) -> None: + """Print summary of uprating applied to targets.""" + uprated = targets_df[targets_df["uprating_factor"] != 1.0] + if len(uprated) == 0: + print("No targets were uprated.") + return + + print("\n" + "=" * 60) + print("UPRATING SUMMARY") + print("=" * 60) + print(f"Uprated {len(uprated)} of {len(targets_df)} targets") + + period_counts = uprated["period"].value_counts().sort_index() + for period, count in period_counts.items(): + print(f" Period {period}: {count} targets") + + factors = uprated["uprating_factor"] + print( + f" Factor range: [{factors.min():.4f}, " f"{factors.max():.4f}]" + ) + def _create_state_sim(self, state: int, n_households: int): """Create a fresh simulation with state_fips set to given state.""" from policyengine_us import Microsimulation @@ -213,16 +446,20 @@ def _create_state_sim(self, state: int, n_households: int): return state_sim def build_matrix( - self, sim, target_filter: dict + self, + sim, + target_filter: dict, ) -> Tuple[pd.DataFrame, sparse.csr_matrix, Dict[str, List[str]]]: """ Build sparse calibration matrix. Args: - sim: Microsimulation instance (used for household_ids, or as template) + sim: Microsimulation instance (used for household_ids, or + as template) target_filter: Dict specifying which targets to include - {"stratum_group_ids": [4]} for SNAP targets - {"target_ids": [123, 456]} for specific targets + - an empty dict {} will fetch all targets Returns: Tuple of (targets_df, X_sparse, household_id_mapping) @@ -235,16 +472,31 @@ def build_matrix( n_cols = n_households * n_cds targets_df = self._query_targets(target_filter) - n_targets = len(targets_df) - if n_targets == 0: + if len(targets_df) == 0: raise ValueError("No targets found matching filter") targets_df["geographic_id"] = targets_df["stratum_id"].apply( self._get_geographic_id ) - # Sort by (geo_level, variable, geographic_id) for contiguous group rows + # Uprate targets from their original period to self.time_period + params = sim.tax_benefit_system.parameters + uprating_factors = self._calculate_uprating_factors(params) + targets_df["original_value"] = targets_df["value"].copy() + targets_df["uprating_factor"] = targets_df.apply( + lambda row: self._get_uprating_info( + row["variable"], row["period"], uprating_factors + )[0], + axis=1, + ) + targets_df["value"] = ( + targets_df["original_value"] * targets_df["uprating_factor"] + ) + + n_targets = len(targets_df) + + # Sort by (geo_level, variable, geographic_id) for contiguous group targets_df["_geo_level"] = targets_df["geographic_id"].apply( _get_geo_level ) @@ -316,24 +568,20 @@ def build_matrix( if not geo_mask.any(): continue - # Evaluate non-geographic constraints at entity level - entity_mask = self._evaluate_constraints_entity_aware( - state_sim, non_geo_constraints, n_households + # Calculate target values with entity-aware handling + # This properly handles count targets (*_count) by counting + # entities rather than summing values + masked_values = self._calculate_target_values_entity_aware( + state_sim, + target["variable"], + non_geo_constraints, + geo_mask, + n_households, ) - # Combine geographic and entity-aware masks - mask = geo_mask & entity_mask - - if not mask.any(): + if not masked_values.any(): continue - target_values = state_sim.calculate( - target["variable"], - self.time_period, - map_to="household", - ).values - masked_values = (target_values * mask).astype(np.float32) - nonzero = np.where(masked_values != 0)[0] if len(nonzero) > 0: X[row_idx, col_start + nonzero] = masked_values[ diff --git a/policyengine_us_data/tests/test_local_area_calibration/conftest.py b/policyengine_us_data/tests/test_local_area_calibration/conftest.py index 20b7f05b..d2708906 100644 --- a/policyengine_us_data/tests/test_local_area_calibration/conftest.py +++ b/policyengine_us_data/tests/test_local_area_calibration/conftest.py @@ -23,6 +23,11 @@ # Format: (variable_name, rtol) # variable_name as per the targets in policy_data.db # rtol is relative tolerance for comparison +# +# NOTE: Count targets (person_count, tax_unit_count) are excluded because +# they have constraints (e.g., age>=5|age<18) that make the X_sparse values +# different from raw sim.calculate() values. Count targets are tested +# separately in test_count_targets.py with controlled mock data. VARIABLES_TO_TEST = [ ("snap", 1e-2), ("income_tax", 1e-2), diff --git a/policyengine_us_data/tests/test_local_area_calibration/test_count_targets.py b/policyengine_us_data/tests/test_local_area_calibration/test_count_targets.py new file mode 100644 index 00000000..46eae4eb --- /dev/null +++ b/policyengine_us_data/tests/test_local_area_calibration/test_count_targets.py @@ -0,0 +1,415 @@ +""" +Tests for count target handling in SparseMatrixBuilder. + +These tests verify that count targets (e.g., person_count, tax_unit_count) +are correctly handled by counting entities that satisfy constraints, rather +than summing values. +""" + +import pytest +import numpy as np +from dataclasses import dataclass + +from policyengine_us_data.datasets.cps.local_area_calibration.sparse_matrix_builder import ( + SparseMatrixBuilder, +) + + +@dataclass +class MockEntity: + """Mock entity with a key attribute.""" + + key: str + + +@dataclass +class MockVariable: + """Mock variable with entity information.""" + + entity: MockEntity + + @classmethod + def create(cls, entity_key: str) -> "MockVariable": + return cls(entity=MockEntity(key=entity_key)) + + +class MockTaxBenefitSystem: + """Mock tax benefit system with variable definitions.""" + + def __init__(self): + self.variables = { + "person_count": MockVariable.create("person"), + "tax_unit_count": MockVariable.create("tax_unit"), + "household_count": MockVariable.create("household"), + "spm_unit_count": MockVariable.create("spm_unit"), + "snap": MockVariable.create("spm_unit"), + } + + +@dataclass +class MockCalculationResult: + """Mock result from simulation.calculate().""" + + values: np.ndarray + + +class MockSimulation: + """Mock simulation for testing count target calculations.""" + + def __init__(self, entity_data: dict, variable_values: dict): + """ + Args: + entity_data: Dict with person_id, household_id, tax_unit_id, + spm_unit_id arrays (all at person level) + variable_values: Dict mapping variable names to their values + at the appropriate entity level + """ + self.entity_data = entity_data + self.variable_values = variable_values + self.tax_benefit_system = MockTaxBenefitSystem() + + def calculate(self, variable: str, map_to: str = None): + """Return mock calculation result.""" + if variable in self.entity_data: + # Entity ID variables + if map_to == "person": + values = np.array(self.entity_data[variable]) + elif map_to == "household": + # Return unique household IDs + values = np.array( + sorted(set(self.entity_data["household_id"])) + ) + else: + values = np.array(self.entity_data[variable]) + elif variable in self.variable_values: + # Regular variables - return at requested level + val_data = self.variable_values[variable] + if map_to == "person": + values = np.array(val_data["person"]) + elif map_to == "household": + values = np.array(val_data["household"]) + else: + values = np.array(val_data.get("default", [])) + else: + values = np.array([]) + + return MockCalculationResult(values=values) + + +@pytest.fixture +def basic_entity_data(): + """ + Create mock entity relationships with known household compositions. + + Household 1 (id=100): 3 people (ages 5, 12, 40) -> 2 aged 5-17 + Household 2 (id=200): 2 people (ages 3, 25) -> 0 aged 5-17 + Household 3 (id=300): 4 people (ages 6, 8, 10, 45) -> 3 aged 5-17 + """ + return { + "person_id": [1, 2, 3, 4, 5, 6, 7, 8, 9], + "household_id": [100, 100, 100, 200, 200, 300, 300, 300, 300], + "tax_unit_id": [10, 10, 10, 20, 20, 30, 30, 30, 30], + "spm_unit_id": [ + 1000, + 1000, + 1000, + 2000, + 2000, + 3000, + 3000, + 3000, + 3000, + ], + } + + +@pytest.fixture +def basic_variable_values(): + """Variable values for basic household composition tests.""" + return { + "age": { + "person": [5, 12, 40, 3, 25, 6, 8, 10, 45], + "household": [40, 25, 45], # Not used for age constraints + }, + "person_count": { + "person": [1, 1, 1, 1, 1, 1, 1, 1, 1], + "household": [3, 2, 4], # Sum per household + }, + "snap": { + "person": [100, 100, 100, 0, 0, 200, 200, 200, 200], + "household": [300, 0, 800], + }, + } + + +@pytest.fixture +def basic_sim(basic_entity_data, basic_variable_values): + """Mock simulation with basic household compositions.""" + return MockSimulation(basic_entity_data, basic_variable_values) + + +@pytest.fixture +def builder(): + """Create a minimal SparseMatrixBuilder (won't use DB for unit tests).""" + return SparseMatrixBuilder( + db_uri="sqlite:///:memory:", + time_period=2023, + cds_to_calibrate=["101"], + ) + + +# Tests for basic count target calculation +class TestCountTargetCalculation: + """Test _calculate_target_values_entity_aware for count targets.""" + + def test_person_count_with_age_constraints(self, builder, basic_sim): + """Test person_count correctly counts persons in age range per HH.""" + # Constraints: age >= 5 AND age < 18 + constraints = [ + {"variable": "age", "operation": ">=", "value": 5}, + {"variable": "age", "operation": "<", "value": 18}, + ] + + geo_mask = np.array([True, True, True]) # All households included + n_households = 3 + + result = builder._calculate_target_values_entity_aware( + basic_sim, + "person_count", + constraints, + geo_mask, + n_households, + ) + + # Expected: HH1 has 2 people (ages 5, 12), HH2 has 0, HH3 has 3 (6,8,10) + expected = np.array([2, 0, 3], dtype=np.float32) + np.testing.assert_array_equal(result, expected) + + def test_person_count_no_constraints(self, builder, basic_sim): + """Test person_count without constraints returns all persons per HH.""" + constraints = [] + geo_mask = np.array([True, True, True]) + n_households = 3 + + result = builder._calculate_target_values_entity_aware( + basic_sim, + "person_count", + constraints, + geo_mask, + n_households, + ) + + # Expected: HH1 has 3 people, HH2 has 2, HH3 has 4 + expected = np.array([3, 2, 4], dtype=np.float32) + np.testing.assert_array_equal(result, expected) + + def test_person_count_with_geo_mask(self, builder, basic_sim): + """Test person_count respects geographic mask.""" + constraints = [ + {"variable": "age", "operation": ">=", "value": 5}, + {"variable": "age", "operation": "<", "value": 18}, + ] + + # Only include households 1 and 3 + geo_mask = np.array([True, False, True]) + n_households = 3 + + result = builder._calculate_target_values_entity_aware( + basic_sim, + "person_count", + constraints, + geo_mask, + n_households, + ) + + # Expected: HH1=2, HH2=0 (masked out), HH3=3 + expected = np.array([2, 0, 3], dtype=np.float32) + np.testing.assert_array_equal(result, expected) + + def test_value_target_uses_sum(self, builder, basic_sim): + """Test that non-count targets sum values (existing behavior).""" + # SNAP is a value target, not a count target + constraints = [] + geo_mask = np.array([True, True, True]) + n_households = 3 + + result = builder._calculate_target_values_entity_aware( + basic_sim, + "snap", + constraints, + geo_mask, + n_households, + ) + + # Expected: Sum of snap values per household + expected = np.array([300, 0, 800], dtype=np.float32) + np.testing.assert_array_equal(result, expected) + + def test_household_count_no_constraints(self, builder, basic_sim): + """Test household_count returns 1 for each qualifying household.""" + constraints = [] + geo_mask = np.array([True, True, True]) + n_households = 3 + + result = builder._calculate_target_values_entity_aware( + basic_sim, + "household_count", + constraints, + geo_mask, + n_households, + ) + + # Expected: 1 for each household in geo_mask + expected = np.array([1, 1, 1], dtype=np.float32) + np.testing.assert_array_equal(result, expected) + + def test_household_count_with_geo_mask(self, builder, basic_sim): + """Test household_count respects geographic mask.""" + constraints = [] + geo_mask = np.array([True, False, True]) + n_households = 3 + + result = builder._calculate_target_values_entity_aware( + basic_sim, + "household_count", + constraints, + geo_mask, + n_households, + ) + + # Expected: 1 for HH1, 0 for HH2 (masked), 1 for HH3 + expected = np.array([1, 0, 1], dtype=np.float32) + np.testing.assert_array_equal(result, expected) + + +# Fixtures for complex entity relationship tests +@pytest.fixture +def complex_entity_data(): + """ + Create entity data with multiple tax units per household. + + Household 1 (id=100): 4 people in 2 tax units + Tax unit 10: person 1 (age 30, filer), person 2 (age 28) + Tax unit 11: person 3 (age 65, filer), person 4 (age 62) + Household 2 (id=200): 2 people in 1 tax unit + Tax unit 20: person 5 (age 45, filer), person 6 (age 16) + """ + return { + "person_id": [1, 2, 3, 4, 5, 6], + "household_id": [100, 100, 100, 100, 200, 200], + "tax_unit_id": [10, 10, 11, 11, 20, 20], + "spm_unit_id": [1000, 1000, 1000, 1000, 2000, 2000], + } + + +@pytest.fixture +def complex_variable_values(): + """Variable values for complex entity relationship tests.""" + return { + "age": { + "person": [30, 28, 65, 62, 45, 16], + "household": [65, 45], + }, + "is_tax_unit_head": { + "person": [True, False, True, False, True, False], + "household": [2, 1], # count of heads per HH + }, + "tax_unit_count": { + "person": [1, 1, 1, 1, 1, 1], + "household": [2, 1], + }, + "person_count": { + "person": [1, 1, 1, 1, 1, 1], + "household": [4, 2], + }, + } + + +@pytest.fixture +def complex_sim(complex_entity_data, complex_variable_values): + """Mock simulation with complex entity relationships.""" + return MockSimulation(complex_entity_data, complex_variable_values) + + +# Tests for complex entity relationships +class TestCountTargetWithRealEntities: + """Test count targets with more complex entity relationships.""" + + def test_tax_unit_count_no_constraints(self, builder, complex_sim): + """Test tax_unit_count counts all tax units per household.""" + constraints = [] + geo_mask = np.array([True, True]) + n_households = 2 + + result = builder._calculate_target_values_entity_aware( + complex_sim, + "tax_unit_count", + constraints, + geo_mask, + n_households, + ) + + # Expected: HH1 has 2 tax units, HH2 has 1 + expected = np.array([2, 1], dtype=np.float32) + np.testing.assert_array_equal(result, expected) + + def test_tax_unit_count_with_age_constraint(self, builder, complex_sim): + """Test tax_unit_count with age constraint on members.""" + # Count tax units that have at least one person aged >= 65 + constraints = [ + {"variable": "age", "operation": ">=", "value": 65}, + ] + geo_mask = np.array([True, True]) + n_households = 2 + + result = builder._calculate_target_values_entity_aware( + complex_sim, + "tax_unit_count", + constraints, + geo_mask, + n_households, + ) + + # Expected: HH1 has 1 tax unit (TU 11) with person >=65, HH2 has 0 + expected = np.array([1, 0], dtype=np.float32) + np.testing.assert_array_equal(result, expected) + + def test_person_count_seniors(self, builder, complex_sim): + """Test person_count for seniors (age >= 65).""" + constraints = [ + {"variable": "age", "operation": ">=", "value": 65}, + ] + geo_mask = np.array([True, True]) + n_households = 2 + + result = builder._calculate_target_values_entity_aware( + complex_sim, + "person_count", + constraints, + geo_mask, + n_households, + ) + + # Expected: HH1 has 1 senior (age 65), HH2 has 0 + expected = np.array([1, 0], dtype=np.float32) + np.testing.assert_array_equal(result, expected) + + def test_person_count_children(self, builder, complex_sim): + """Test person_count for children (age < 18).""" + constraints = [ + {"variable": "age", "operation": "<", "value": 18}, + ] + geo_mask = np.array([True, True]) + n_households = 2 + + result = builder._calculate_target_values_entity_aware( + complex_sim, + "person_count", + constraints, + geo_mask, + n_households, + ) + + # Expected: HH1 has 0 children, HH2 has 1 (age 16) + expected = np.array([0, 1], dtype=np.float32) + np.testing.assert_array_equal(result, expected) diff --git a/policyengine_us_data/tests/test_local_area_calibration/test_period_selection_and_uprating.py b/policyengine_us_data/tests/test_local_area_calibration/test_period_selection_and_uprating.py new file mode 100644 index 00000000..639dc736 --- /dev/null +++ b/policyengine_us_data/tests/test_local_area_calibration/test_period_selection_and_uprating.py @@ -0,0 +1,256 @@ +""" +Tests for best-period selection and uprating in SparseMatrixBuilder. +""" + +import unittest +import tempfile +import os +import pandas as pd +from sqlalchemy import create_engine, text + +from policyengine_us_data.datasets.cps.local_area_calibration.sparse_matrix_builder import ( + SparseMatrixBuilder, +) + + +class TestPeriodSelectionAndUprating(unittest.TestCase): + """Test best-period SQL CTE and uprating logic.""" + + @classmethod + def setUpClass(cls): + cls.temp_db = tempfile.NamedTemporaryFile(suffix=".db", delete=False) + cls.db_path = cls.temp_db.name + cls.temp_db.close() + + cls.db_uri = f"sqlite:///{cls.db_path}" + engine = create_engine(cls.db_uri) + + with engine.connect() as conn: + conn.execute( + text( + "CREATE TABLE stratum_groups (" + "stratum_group_id INTEGER PRIMARY KEY, " + "name TEXT)" + ) + ) + conn.execute( + text( + "CREATE TABLE strata (" + "stratum_id INTEGER PRIMARY KEY, " + "stratum_group_id INTEGER)" + ) + ) + conn.execute( + text( + "CREATE TABLE stratum_constraints (" + "constraint_id INTEGER PRIMARY KEY, " + "stratum_id INTEGER, " + "constraint_variable TEXT, " + "operation TEXT, " + "value TEXT)" + ) + ) + conn.execute( + text( + "CREATE TABLE targets (" + "target_id INTEGER PRIMARY KEY, " + "stratum_id INTEGER, " + "variable TEXT, " + "value REAL, " + "period INTEGER)" + ) + ) + conn.commit() + + @classmethod + def tearDownClass(cls): + os.unlink(cls.db_path) + + def setUp(self): + engine = create_engine(self.db_uri) + with engine.connect() as conn: + conn.execute(text("DELETE FROM targets")) + conn.execute(text("DELETE FROM stratum_constraints")) + conn.execute(text("DELETE FROM strata")) + conn.execute(text("DELETE FROM stratum_groups")) + conn.commit() + + def _insert_test_data(self, strata, constraints, targets): + engine = create_engine(self.db_uri) + with engine.connect() as conn: + conn.execute( + text( + "INSERT OR IGNORE INTO stratum_groups " + "VALUES (1, 'test')" + ) + ) + for stratum_id, group_id in strata: + conn.execute( + text("INSERT INTO strata VALUES (:sid, :gid)"), + {"sid": stratum_id, "gid": group_id}, + ) + for i, (stratum_id, var, op, val) in enumerate(constraints): + conn.execute( + text( + "INSERT INTO stratum_constraints " + "VALUES (:cid, :sid, :var, :op, :val)" + ), + { + "cid": i + 1, + "sid": stratum_id, + "var": var, + "op": op, + "val": val, + }, + ) + for i, ( + stratum_id, + variable, + value, + period, + ) in enumerate(targets): + conn.execute( + text( + "INSERT INTO targets " + "VALUES (:tid, :sid, :var, :val, :period)" + ), + { + "tid": i + 1, + "sid": stratum_id, + "var": variable, + "val": value, + "period": period, + }, + ) + conn.commit() + + def _make_builder(self, time_period=2024): + return SparseMatrixBuilder( + db_uri=self.db_uri, + time_period=time_period, + cds_to_calibrate=["601"], + ) + + # ---- Period selection tests ---- + + def test_best_period_prefers_past(self): + """Targets at 2022 and 2026 -> picks 2022 for time_period=2024.""" + self._insert_test_data( + strata=[(1, 1)], + constraints=[ + (1, "congressional_district_geoid", "=", "601"), + ], + targets=[ + (1, "snap", 1000, 2022), + (1, "snap", 2000, 2026), + ], + ) + builder = self._make_builder(time_period=2024) + df = builder._query_targets({"stratum_group_ids": [1]}) + self.assertEqual(len(df), 1) + self.assertEqual(df.iloc[0]["period"], 2022) + self.assertEqual(df.iloc[0]["value"], 1000) + + def test_best_period_uses_future_when_no_past(self): + """Target only at 2026 -> picks 2026 for time_period=2024.""" + self._insert_test_data( + strata=[(1, 1)], + constraints=[ + (1, "congressional_district_geoid", "=", "601"), + ], + targets=[ + (1, "snap", 5000, 2026), + ], + ) + builder = self._make_builder(time_period=2024) + df = builder._query_targets({"stratum_group_ids": [1]}) + self.assertEqual(len(df), 1) + self.assertEqual(df.iloc[0]["period"], 2026) + + def test_best_period_exact_match(self): + """Targets at 2022, 2024, 2026 -> picks 2024 exactly.""" + self._insert_test_data( + strata=[(1, 1)], + constraints=[ + (1, "congressional_district_geoid", "=", "601"), + ], + targets=[ + (1, "snap", 1000, 2022), + (1, "snap", 1500, 2024), + (1, "snap", 2000, 2026), + ], + ) + builder = self._make_builder(time_period=2024) + df = builder._query_targets({"stratum_group_ids": [1]}) + self.assertEqual(len(df), 1) + self.assertEqual(df.iloc[0]["period"], 2024) + self.assertEqual(df.iloc[0]["value"], 1500) + + def test_independent_per_stratum_and_variable(self): + """Different strata/variables select independently.""" + self._insert_test_data( + strata=[(1, 1), (2, 1)], + constraints=[ + (1, "congressional_district_geoid", "=", "601"), + (2, "congressional_district_geoid", "=", "601"), + ], + targets=[ + (1, "snap", 1000, 2024), + (1, "snap", 800, 2022), + (2, "person_count", 500, 2022), + (2, "person_count", 600, 2026), + ], + ) + builder = self._make_builder(time_period=2024) + df = builder._query_targets({"stratum_group_ids": [1]}) + self.assertEqual(len(df), 2) + snap_row = df[df["variable"] == "snap"].iloc[0] + self.assertEqual(snap_row["period"], 2024) + count_row = df[df["variable"] == "person_count"].iloc[0] + self.assertEqual(count_row["period"], 2022) + + # ---- Uprating info tests ---- + + def test_cpi_uprating_for_dollar_vars(self): + builder = self._make_builder(time_period=2024) + factors = {(2022, "cpi"): 1.06, (2022, "pop"): 1.01} + factor, type_ = builder._get_uprating_info("snap", 2022, factors) + self.assertAlmostEqual(factor, 1.06) + self.assertEqual(type_, "cpi") + + def test_pop_uprating_for_count_vars(self): + builder = self._make_builder(time_period=2024) + factors = {(2022, "cpi"): 1.06, (2022, "pop"): 1.01} + factor, type_ = builder._get_uprating_info( + "person_count", 2022, factors + ) + self.assertAlmostEqual(factor, 1.01) + self.assertEqual(type_, "pop") + + def test_no_uprating_for_current_period(self): + builder = self._make_builder(time_period=2024) + factors = {(2024, "cpi"): 1.0, (2024, "pop"): 1.0} + factor, type_ = builder._get_uprating_info("snap", 2024, factors) + self.assertAlmostEqual(factor, 1.0) + self.assertEqual(type_, "none") + + def test_pop_uprating_households_variable(self): + builder = self._make_builder(time_period=2024) + factors = {(2022, "cpi"): 1.06, (2022, "pop"): 1.02} + factor, type_ = builder._get_uprating_info("households", 2022, factors) + self.assertAlmostEqual(factor, 1.02) + self.assertEqual(type_, "pop") + + def test_pop_uprating_tax_units_variable(self): + builder = self._make_builder(time_period=2024) + factors = {(2022, "cpi"): 1.06, (2022, "pop"): 1.02} + factor, type_ = builder._get_uprating_info("tax_units", 2022, factors) + self.assertAlmostEqual(factor, 1.02) + self.assertEqual(type_, "pop") + + def test_missing_factor_defaults_to_1(self): + builder = self._make_builder(time_period=2024) + factors = {} + factor, type_ = builder._get_uprating_info("snap", 2020, factors) + self.assertAlmostEqual(factor, 1.0) + self.assertEqual(type_, "cpi")