From 7de3a1e0e1c2d172d35965389135892ccbcbfb42 Mon Sep 17 00:00:00 2001
From: Max Ghenis
Date: Sun, 17 May 2026 14:04:37 -0400
Subject: [PATCH 1/2] Make simulation input exports safe by default

---

Review notes (commentary only; not part of the commit message):

    Purpose: adds get_input_variables(), the true_input_variables
    property and two private helpers, then routes to_input_dataframe()
    and to_input_dict() through _get_exportable_input_periods() so that,
    by default, only variables passing is_input_variable() with a
    matching entry in _user_input_keys are exported.

    * get_input_variables() defaults include_computed_variables=True,
      while to_input_dataframe()/to_input_dict() default it to False.
      Presumably deliberate for backward compatibility (the docstring
      calls the True path "legacy") - TODO confirm, and consider saying
      so in the docstring.
    * _get_exportable_input_periods() resolves the variable twice: once
      inside _is_exportable_input_variable() and again via
      tax_benefit_system.get_variable(). One lookup would do.
    * For ETERNITY variables the helper returns every known period of
      the holder once any user-input key exists for the variable; it
      does not intersect with user_input_periods as the non-ETERNITY
      path does.
    * _user_input_keys is read with getattr(..., set()). Where it is
      populated is not visible in this patch; if the attribute is
      missing, the safe export is silently empty - TODO confirm this is
      the intended fallback.
    * to_input_dict() still assigns data[variable] = {} before the
      period loop; the new test's `not in exported_data` assertion
      relies on empty entries being pruned by code outside this hunk -
      NOTE(review): verify.

 changelog.d/safe-input-export.changed.md    |   1 +
 policyengine_core/simulations/simulation.py |  94 ++++++++++++++++-
 tests/core/test_simulations.py              | 106 ++++++++++++++++++++
 3 files changed, 196 insertions(+), 5 deletions(-)
 create mode 100644 changelog.d/safe-input-export.changed.md

diff --git a/changelog.d/safe-input-export.changed.md b/changelog.d/safe-input-export.changed.md
new file mode 100644
index 00000000..186551b5
--- /dev/null
+++ b/changelog.d/safe-input-export.changed.md
@@ -0,0 +1 @@
+Exclude pseudo-inputs and calculated values from simulation input exports by default.
diff --git a/policyengine_core/simulations/simulation.py b/policyengine_core/simulations/simulation.py
index 2bdfe045..5fed0140 100644
--- a/policyengine_core/simulations/simulation.py
+++ b/policyengine_core/simulations/simulation.py
@@ -1588,10 +1588,82 @@ def check_macro_cache(self, variable_name: str, period: str) -> bool:
 
         return True
 
+    def get_input_variables(self, include_computed_variables: bool = True) -> List[str]:
+        """Return variable names stored as inputs on this simulation.
+
+        Args:
+            include_computed_variables: When ``True``, return the legacy
+                runtime list of variables with stored values. When ``False``,
+                return only structurally input variables that were populated
+                through ``set_input`` on the current branch.
+
+        Returns:
+            List[str]: Stored input variable names.
+        """
+        if include_computed_variables:
+            return list(self.input_variables)
+
+        return [
+            variable_name
+            for variable_name in self.tax_benefit_system.variables
+            if len(
+                self._get_exportable_input_periods(
+                    variable_name,
+                    include_computed_variables=False,
+                )
+            )
+            > 0
+        ]
+
+    @property
+    def true_input_variables(self) -> List[str]:
+        """Stored variables that are safe to reload as source inputs."""
+        return self.get_input_variables(include_computed_variables=False)
+
+    def _is_exportable_input_variable(self, variable_name: str) -> bool:
+        variable = self.tax_benefit_system.get_variable(variable_name)
+        return variable is not None and variable.is_input_variable()
+
+    def _get_exportable_input_periods(
+        self,
+        variable_name: str,
+        include_computed_variables: bool,
+    ) -> List[Period]:
+        if include_computed_variables:
+            return self.get_holder(variable_name).get_known_periods()
+
+        if not self._is_exportable_input_variable(variable_name):
+            return []
+
+        user_input_periods = {
+            period
+            for input_variable_name, branch_name, period in getattr(
+                self, "_user_input_keys", set()
+            )
+            if input_variable_name == variable_name and branch_name == self.branch_name
+        }
+        if not user_input_periods:
+            return []
+        variable = self.tax_benefit_system.get_variable(variable_name)
+        holder = self.get_holder(variable_name)
+        if variable.definition_period == ETERNITY:
+            return holder.get_known_periods()
+        known_periods = set(holder.get_known_periods())
+        return sorted(user_input_periods & known_periods, key=str)
+
     def to_input_dataframe(
         self,
+        include_computed_variables: bool = False,
     ) -> pd.DataFrame:
-        """Exports a DataFrame which can be loaded back to a new Simulation to reproduce the same results.
+        """Exports a DataFrame that can be loaded back into a new Simulation.
+
+        By default, only structurally input variables populated through
+        ``set_input`` are exported. This avoids serializing pseudo-inputs and
+        stale calculated values that would override formulas when reloaded.
+
+        Args:
+            include_computed_variables: If ``True``, export every variable with
+                a known period, matching the historical unsafe behavior.
 
         Returns:
             pd.DataFrame: The DataFrame containing the input values.
@@ -1601,7 +1673,9 @@ def to_input_dataframe(
 
         for variable in self.tax_benefit_system.variables:
             variable_meta = self.tax_benefit_system.variables[variable]
-            for period in self.get_holder(variable).get_known_periods():
+            for period in self._get_exportable_input_periods(
+                variable, include_computed_variables
+            ):
                 # Test if period matches entity definition period
                 if variable_meta.definition_period != period.unit:
                     continue
@@ -1611,8 +1685,16 @@ def to_input_dataframe(
 
         return df
 
-    def to_input_dict(self) -> dict:
-        """Exports a dictionary which can be loaded back to a new Simulation to reproduce the same results.
+    def to_input_dict(self, include_computed_variables: bool = False) -> dict:
+        """Exports a dictionary that can be loaded back into a new Simulation.
+
+        By default, only structurally input variables populated through
+        ``set_input`` are exported. This avoids serializing pseudo-inputs and
+        stale calculated values that would override formulas when reloaded.
+
+        Args:
+            include_computed_variables: If ``True``, export every variable with
+                a known period, matching the historical unsafe behavior.
 
         Returns:
             dict: The dictionary containing the input values.
@@ -1621,7 +1703,9 @@ def to_input_dict(self) -> dict:
 
         for variable in self.tax_benefit_system.variables:
             data[variable] = {}
-            for period in self.get_holder(variable).get_known_periods():
+            for period in self._get_exportable_input_periods(
+                variable, include_computed_variables
+            ):
                 values = self.calculate(variable, period, map_to="person")
                 if values is not None:
                     data[variable][str(period)] = values.tolist()
diff --git a/tests/core/test_simulations.py b/tests/core/test_simulations.py
index ba85f0e8..fdd16f8f 100644
--- a/tests/core/test_simulations.py
+++ b/tests/core/test_simulations.py
@@ -1,4 +1,9 @@
 from policyengine_core.country_template.situation_examples import single
+from policyengine_core.country_template import Simulation as CountryTemplateSimulation
+from policyengine_core.country_template.entities import Person
+from policyengine_core.data import Dataset
+from policyengine_core.model_api import Variable
+from policyengine_core.periods import MONTH
 from policyengine_core.simulations import SimulationBuilder
 import policyengine_core.simulations.simulation as simulation_module
 from policyengine_core.simulations.simulation_macro_cache import (
@@ -6,6 +11,7 @@
 )
 import importlib.metadata
 import numpy as np
+import pandas as pd
 from pathlib import Path
 
 
@@ -112,3 +118,103 @@ def __init__(self, tax_benefit_system):
 
     simulation = SimulationBuilder().build_default_simulation(tax_benefit_system)
     simulation.calculate("income_tax", "2017-01")
+
+
+class formula_component_for_safe_export(Variable):
+    value_type = float
+    entity = Person
+    definition_period = MONTH
+    label = "Formula component for safe export tests."
+
+    def formula(person, period):
+        return person("salary", period) * 0
+
+
+class pseudo_input_for_safe_export(Variable):
+    value_type = float
+    entity = Person
+    definition_period = MONTH
+    label = "Pseudo-input for safe export tests."
+    adds = ["formula_component_for_safe_export"]
+
+
+def _safe_export_dataset(dataframe):
+    return Dataset.from_dataframe(dataframe, "2022-01")
+
+
+def _safe_export_simulation(isolated_tax_benefit_system):
+    isolated_tax_benefit_system.add_variable(formula_component_for_safe_export)
+    isolated_tax_benefit_system.add_variable(pseudo_input_for_safe_export)
+
+    dataframe = pd.DataFrame(
+        {
+            "person_id__2022": [0],
+            "household_id__2022": [0],
+            "person_household_id__2022": [0],
+            "person_household_role__2022": ["parent"],
+            "household_weight__2022": [1.0],
+            "salary__2022-01": [0.0],
+            "pseudo_input_for_safe_export__2022-01": [999.0],
+        }
+    )
+    return CountryTemplateSimulation(
+        tax_benefit_system=isolated_tax_benefit_system,
+        dataset=_safe_export_dataset(dataframe),
+    )
+
+
+def test__given_pseudo_input_in_dataset__then_input_dataframe_excludes_it(
+    isolated_tax_benefit_system,
+):
+    # Given
+    simulation = _safe_export_simulation(isolated_tax_benefit_system)
+
+    assert simulation.calculate("pseudo_input_for_safe_export", "2022-01")[0] == 999.0
+
+    # When
+    dataframe = simulation.to_input_dataframe()
+    reloaded = CountryTemplateSimulation(
+        tax_benefit_system=isolated_tax_benefit_system,
+        dataset=_safe_export_dataset(dataframe),
+    )
+
+    # Then
+    assert "salary__2022-01" in dataframe.columns
+    assert "pseudo_input_for_safe_export__2022-01" not in dataframe.columns
+    assert "salary" in simulation.true_input_variables
+    assert "pseudo_input_for_safe_export" not in simulation.true_input_variables
+    assert (
+        "pseudo_input_for_safe_export__2022-01"
+        in simulation.to_input_dataframe(include_computed_variables=True).columns
+    )
+    assert reloaded.calculate("pseudo_input_for_safe_export", "2022-01")[0] == 0.0
+
+
+def test__given_pseudo_input_in_dataset__then_input_dict_h5_round_trip_excludes_it(
+    isolated_tax_benefit_system, tmp_path
+):
+    # Given
+    simulation = _safe_export_simulation(isolated_tax_benefit_system)
+    exported_data = simulation.to_input_dict()
+    h5_path = tmp_path / "safe_export.h5"
+
+    class SafeExportDataset(Dataset):
+        name = "safe_export"
+        label = "Safe export"
+        file_path = h5_path
+        data_format = Dataset.TIME_PERIOD_ARRAYS
+
+    # When
+    SafeExportDataset().save_dataset(exported_data)
+    reloaded = CountryTemplateSimulation(
+        tax_benefit_system=isolated_tax_benefit_system,
+        dataset=Dataset.from_file(h5_path),
+    )
+
+    # Then
+    assert "salary" in exported_data
+    assert "pseudo_input_for_safe_export" not in exported_data
+    assert "pseudo_input_for_safe_export" in simulation.to_input_dict(
+        include_computed_variables=True
+    )
+    assert reloaded.calculate("pseudo_input_for_safe_export", "2022-01")[0] == 0.0

From a44b78dae2d27ef5a326f18a14c3deb955445141 Mon Sep 17 00:00:00 2001
From: Max Ghenis
Date: Sun, 17 May 2026 14:11:22 -0400
Subject: [PATCH 2/2] Preserve inherited inputs in safe exports

---

Review notes (commentary only; not part of the commit message):

    Purpose: adds _get_visible_branch_names(), which collects this
    simulation's branch_name, each parent_branch's branch_name while
    walking up the chain, and finally "default" (de-duplicated, order
    preserved). _get_exportable_input_periods() now accepts user-input
    keys from any of those branches instead of only self.branch_name.

    * self._get_visible_branch_names() is evaluated inside the set
      comprehension's filter, i.e. once per _user_input_keys entry whose
      variable name matches, and each call re-walks the parent chain.
      The result does not change during the loop; suggested fix:
          visible = set(self._get_visible_branch_names())
      before the comprehension, then `branch_name in visible`.
    * The get_input_variables() docstring added in PATCH 1/2 still says
      inputs populated "on the current branch"; after this patch parent
      branches and "default" also count. The docstring should be updated
      here.
    * "default" is appended unconditionally, even when the parent chain
      never reaches a branch with that name - NOTE(review): confirm that
      every simulation is meant to inherit from "default".
    * _get_visible_branch_names() has no docstring; the ordering
      (self first, then ancestors, then "default") is worth stating.

 policyengine_core/simulations/simulation.py | 12 ++++++++++-
 tests/core/test_simulations.py              | 24 +++++++++++++++++++++
 2 files changed, 35 insertions(+), 1 deletion(-)

diff --git a/policyengine_core/simulations/simulation.py b/policyengine_core/simulations/simulation.py
index 5fed0140..92e54eda 100644
--- a/policyengine_core/simulations/simulation.py
+++ b/policyengine_core/simulations/simulation.py
@@ -1624,6 +1624,15 @@ def _is_exportable_input_variable(self, variable_name: str) -> bool:
         variable = self.tax_benefit_system.get_variable(variable_name)
         return variable is not None and variable.is_input_variable()
 
+    def _get_visible_branch_names(self) -> List[str]:
+        branch_names = [self.branch_name]
+        parent = getattr(self, "parent_branch", None)
+        while parent is not None:
+            branch_names.append(parent.branch_name)
+            parent = getattr(parent, "parent_branch", None)
+        branch_names.append("default")
+        return list(dict.fromkeys(branch_names))
+
     def _get_exportable_input_periods(
         self,
         variable_name: str,
@@ -1640,7 +1649,8 @@ def _get_exportable_input_periods(
             for input_variable_name, branch_name, period in getattr(
                 self, "_user_input_keys", set()
             )
-            if input_variable_name == variable_name and branch_name == self.branch_name
+            if input_variable_name == variable_name
+            and branch_name in self._get_visible_branch_names()
         }
         if not user_input_periods:
             return []
diff --git a/tests/core/test_simulations.py b/tests/core/test_simulations.py
index fdd16f8f..5118021b 100644
--- a/tests/core/test_simulations.py
+++ b/tests/core/test_simulations.py
@@ -218,3 +218,27 @@ class SafeExportDataset(Dataset):
         include_computed_variables=True
     )
     assert reloaded.calculate("pseudo_input_for_safe_export", "2022-01")[0] == 0.0
+
+
+def test__given_branch_inherits_dataset_inputs__then_safe_exports_include_them(
+    isolated_tax_benefit_system,
+):
+    # Given
+    simulation = _safe_export_simulation(isolated_tax_benefit_system)
+    branch = simulation.get_branch("reform")
+
+    assert branch.calculate("salary", "2022-01")[0] == 0.0
+
+    # When
+    dataframe = branch.to_input_dataframe()
+    exported_data = branch.to_input_dict()
+
+    # Then
+    assert "person_id__ETERNITY" in dataframe.columns
+    assert "household_id__ETERNITY" in dataframe.columns
+    assert "household_weight__2022" in dataframe.columns
+    assert "salary__2022-01" in dataframe.columns
+    assert "pseudo_input_for_safe_export__2022-01" not in dataframe.columns
+    assert "salary" in exported_data
+    assert "pseudo_input_for_safe_export" not in exported_data
+    assert "salary" in branch.true_input_variables