Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/safe-input-export.changed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Exclude pseudo-inputs and calculated values from simulation input exports by default.
104 changes: 99 additions & 5 deletions policyengine_core/simulations/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1588,10 +1588,92 @@ def check_macro_cache(self, variable_name: str, period: str) -> bool:

return True

def get_input_variables(self, include_computed_variables: bool = True) -> List[str]:
    """List the names of variables held as inputs by this simulation.

    Args:
        include_computed_variables: ``True`` (the default) returns a copy
            of ``self.input_variables`` unchanged. ``False`` keeps only the
            variables of the tax-benefit system for which
            ``_get_exportable_input_periods`` reports at least one period.

    Returns:
        List[str]: The selected variable names.
    """
    if not include_computed_variables:
        exportable_names = []
        for candidate_name in self.tax_benefit_system.variables:
            periods = self._get_exportable_input_periods(
                candidate_name,
                include_computed_variables=False,
            )
            # A variable qualifies as soon as one exportable period exists.
            if len(periods) > 0:
                exportable_names.append(candidate_name)
        return exportable_names

    # Legacy behaviour: hand back a copy so callers cannot mutate our list.
    return list(self.input_variables)

@property
def true_input_variables(self) -> List[str]:
    """Input variable names with computed variables left out.

    Equivalent to ``get_input_variables(include_computed_variables=False)``.
    """
    safe_names = self.get_input_variables(include_computed_variables=False)
    return safe_names

def _is_exportable_input_variable(self, variable_name: str) -> bool:
    """Tell whether ``variable_name`` resolves to an input variable.

    Returns ``False`` when the tax-benefit system has no such variable;
    otherwise returns whatever the variable's ``is_input_variable()`` says.
    """
    variable = self.tax_benefit_system.get_variable(variable_name)
    if variable is None:
        return False
    return variable.is_input_variable()

def _get_visible_branch_names(self) -> List[str]:
    """Collect the branch names reachable from this simulation.

    Starts with this simulation's own ``branch_name``, follows the
    ``parent_branch`` links for as long as they exist, and finally adds
    ``"default"``. Each name appears once, at its first position.

    Returns:
        List[str]: De-duplicated branch names, nearest branch first.
    """
    ordered_names = []
    already_listed = set()

    node = self
    while node is not None:
        name = node.branch_name
        if name not in already_listed:
            already_listed.add(name)
            ordered_names.append(name)
        # Objects without a ``parent_branch`` attribute end the walk.
        node = getattr(node, "parent_branch", None)

    if "default" not in already_listed:
        ordered_names.append("default")
    return ordered_names

def _get_exportable_input_periods(
    self,
    variable_name: str,
    include_computed_variables: bool,
) -> List[Period]:
    """Return the periods of ``variable_name`` that should be exported.

    Args:
        variable_name: Name of the variable to inspect.
        include_computed_variables: When ``True``, every period known to the
            variable's holder is returned without filtering. When ``False``,
            only variables accepted by ``_is_exportable_input_variable`` are
            considered, and only periods recorded in ``_user_input_keys``
            for a branch listed by ``_get_visible_branch_names``.

    Returns:
        List[Period]: The exportable periods. Empty when the variable is not
        exportable or has no recorded user input on a visible branch. For a
        variable whose definition period is ``ETERNITY``, all known periods
        of the holder are returned; otherwise the recorded periods that the
        holder also knows, sorted by their string form.
    """
    if include_computed_variables:
        return self.get_holder(variable_name).get_known_periods()

    if not self._is_exportable_input_variable(variable_name):
        return []

    # The visible branches do not depend on the key being examined, so they
    # are resolved once here rather than once per matching key.
    visible_branches = frozenset(self._get_visible_branch_names())

    # ``_user_input_keys`` is unpacked as (variable name, branch name, period)
    # triples; simulations lacking the attribute are treated as having none.
    # NOTE(review): presumably filled in by ``set_input`` - confirm at the
    # definition site, which is outside this block.
    user_input_periods = {
        period
        for input_variable_name, branch_name, period in getattr(
            self, "_user_input_keys", ()
        )
        if input_variable_name == variable_name and branch_name in visible_branches
    }
    if not user_input_periods:
        return []

    variable = self.tax_benefit_system.get_variable(variable_name)
    known_periods = self.get_holder(variable_name).get_known_periods()
    if variable.definition_period == ETERNITY:
        # Eternal variables are exported with whatever periods the holder
        # reports, not with the periods recorded at input time.
        return known_periods

    # Sorting on the string form gives a deterministic export order.
    return sorted(user_input_periods.intersection(known_periods), key=str)

def to_input_dataframe(
self,
include_computed_variables: bool = False,
) -> pd.DataFrame:
"""Exports a DataFrame which can be loaded back to a new Simulation to reproduce the same results.
"""Exports a DataFrame that can be loaded back into a new Simulation.

By default, only structurally input variables populated through
``set_input`` are exported. This avoids serializing pseudo-inputs and
stale calculated values that would override formulas when reloaded.

Args:
include_computed_variables: If ``True``, export every variable with
a known period, matching the historical unsafe behavior.

Returns:
pd.DataFrame: The DataFrame containing the input values.
Expand All @@ -1601,7 +1683,9 @@ def to_input_dataframe(

for variable in self.tax_benefit_system.variables:
variable_meta = self.tax_benefit_system.variables[variable]
for period in self.get_holder(variable).get_known_periods():
for period in self._get_exportable_input_periods(
variable, include_computed_variables
):
# Test if period matches entity definition period
if variable_meta.definition_period != period.unit:
continue
Expand All @@ -1611,8 +1695,16 @@ def to_input_dataframe(

return df

def to_input_dict(self) -> dict:
"""Exports a dictionary which can be loaded back to a new Simulation to reproduce the same results.
def to_input_dict(self, include_computed_variables: bool = False) -> dict:
"""Exports a dictionary that can be loaded back into a new Simulation.

By default, only structurally input variables populated through
``set_input`` are exported. This avoids serializing pseudo-inputs and
stale calculated values that would override formulas when reloaded.

Args:
include_computed_variables: If ``True``, export every variable with
a known period, matching the historical unsafe behavior.

Returns:
dict: The dictionary containing the input values.
Expand All @@ -1621,7 +1713,9 @@ def to_input_dict(self) -> dict:

for variable in self.tax_benefit_system.variables:
data[variable] = {}
for period in self.get_holder(variable).get_known_periods():
for period in self._get_exportable_input_periods(
variable, include_computed_variables
):
values = self.calculate(variable, period, map_to="person")
if values is not None:
data[variable][str(period)] = values.tolist()
Expand Down
130 changes: 130 additions & 0 deletions tests/core/test_simulations.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
from policyengine_core.country_template.situation_examples import single
from policyengine_core.country_template import Simulation as CountryTemplateSimulation
from policyengine_core.country_template.entities import Person
from policyengine_core.data import Dataset
from policyengine_core.model_api import Variable
from policyengine_core.periods import MONTH
from policyengine_core.simulations import SimulationBuilder
import policyengine_core.simulations.simulation as simulation_module
from policyengine_core.simulations.simulation_macro_cache import (
SimulationMacroCache,
)
import importlib.metadata
import numpy as np
import pandas as pd
from pathlib import Path


Expand Down Expand Up @@ -112,3 +118,127 @@ def __init__(self, tax_benefit_system):
simulation = SimulationBuilder().build_default_simulation(tax_benefit_system)

simulation.calculate("income_tax", "2017-01")


class formula_component_for_safe_export(Variable):
    # Monthly, person-level float variable backed by a formula. It is the
    # single component summed by ``pseudo_input_for_safe_export``.
    label = "Formula component for safe export tests."
    definition_period = MONTH
    entity = Person
    value_type = float

    def formula(person, period):
        # Salary multiplied by zero: the result is zero for every person.
        salary = person("salary", period)
        return salary * 0


class pseudo_input_for_safe_export(Variable):
    # Monthly, person-level float variable with no formula of its own; its
    # ``adds`` list names ``formula_component_for_safe_export``.
    label = "Pseudo-input for safe export tests."
    adds = ["formula_component_for_safe_export"]
    definition_period = MONTH
    entity = Person
    value_type = float


def _safe_export_dataset(dataframe):
    """Build a ``Dataset`` from ``dataframe``, passing ``"2022-01"`` as the second argument."""
    dataset = Dataset.from_dataframe(dataframe, "2022-01")
    return dataset


def _safe_export_simulation(isolated_tax_benefit_system):
    """Create a one-person simulation whose dataset sets the pseudo-input.

    Registers the two test variables on ``isolated_tax_benefit_system`` and
    loads a single-row dataset in which ``salary`` is 0.0 and
    ``pseudo_input_for_safe_export`` is 999.0 for ``2022-01``.
    """
    for test_variable in (
        formula_component_for_safe_export,
        pseudo_input_for_safe_export,
    ):
        isolated_tax_benefit_system.add_variable(test_variable)

    # Column names follow the ``<variable>__<period>`` pattern.
    columns = {
        "person_id__2022": [0],
        "household_id__2022": [0],
        "person_household_id__2022": [0],
        "person_household_role__2022": ["parent"],
        "household_weight__2022": [1.0],
        "salary__2022-01": [0.0],
        "pseudo_input_for_safe_export__2022-01": [999.0],
    }
    dataset = _safe_export_dataset(pd.DataFrame(columns))

    return CountryTemplateSimulation(
        tax_benefit_system=isolated_tax_benefit_system,
        dataset=dataset,
    )


def test__given_pseudo_input_in_dataset__then_input_dataframe_excludes_it(
    isolated_tax_benefit_system,
):
    """The default DataFrame export drops a pseudo-input supplied by the dataset."""
    pseudo_column = "pseudo_input_for_safe_export__2022-01"

    # Given: the dataset value (999.0) is what the source simulation reports.
    simulation = _safe_export_simulation(isolated_tax_benefit_system)
    source_value = simulation.calculate("pseudo_input_for_safe_export", "2022-01")[0]
    assert source_value == 999.0

    # When: export with the defaults and reload the export.
    dataframe = simulation.to_input_dataframe()
    reloaded = CountryTemplateSimulation(
        tax_benefit_system=isolated_tax_benefit_system,
        dataset=_safe_export_dataset(dataframe),
    )

    # Then: salary is exported and the pseudo-input is not.
    assert "salary__2022-01" in dataframe.columns
    assert pseudo_column not in dataframe.columns
    assert "salary" in simulation.true_input_variables
    assert "pseudo_input_for_safe_export" not in simulation.true_input_variables

    # Opting in to computed variables brings the pseudo-input column back.
    full_export = simulation.to_input_dataframe(include_computed_variables=True)
    assert pseudo_column in full_export.columns

    # The reloaded simulation no longer sees 999.0.
    reloaded_value = reloaded.calculate("pseudo_input_for_safe_export", "2022-01")[0]
    assert reloaded_value == 0.0


def test__given_pseudo_input_in_dataset__then_input_dict_h5_round_trip_excludes_it(
    isolated_tax_benefit_system, tmp_path
):
    """The default dict export drops the pseudo-input, also after an H5 round trip."""
    pseudo_name = "pseudo_input_for_safe_export"

    # Given
    simulation = _safe_export_simulation(isolated_tax_benefit_system)
    exported_data = simulation.to_input_dict()
    h5_path = tmp_path / "safe_export.h5"

    class SafeExportDataset(Dataset):
        # Throwaway dataset class used to write the export under ``tmp_path``.
        name = "safe_export"
        label = "Safe export"
        file_path = h5_path
        data_format = Dataset.TIME_PERIOD_ARRAYS

    # When: save the exported dict and build a new simulation from the file.
    SafeExportDataset().save_dataset(exported_data)
    reloaded_dataset = Dataset.from_file(h5_path)
    reloaded = CountryTemplateSimulation(
        tax_benefit_system=isolated_tax_benefit_system,
        dataset=reloaded_dataset,
    )

    # Then
    assert "salary" in exported_data
    assert pseudo_name not in exported_data

    full_export = simulation.to_input_dict(include_computed_variables=True)
    assert pseudo_name in full_export

    reloaded_value = reloaded.calculate(pseudo_name, "2022-01")[0]
    assert reloaded_value == 0.0


def test__given_branch_inherits_dataset_inputs__then_safe_exports_include_them(
    isolated_tax_benefit_system,
):
    """Exports taken from a branch still contain the dataset's inputs."""
    # Given: a "reform" branch obtained from the dataset-backed simulation.
    simulation = _safe_export_simulation(isolated_tax_benefit_system)
    branch = simulation.get_branch("reform")

    branch_salary = branch.calculate("salary", "2022-01")[0]
    assert branch_salary == 0.0

    # When
    dataframe = branch.to_input_dataframe()
    exported_data = branch.to_input_dict()

    # Then: the dataset inputs appear in the branch's DataFrame export.
    expected_columns = (
        "person_id__ETERNITY",
        "household_id__ETERNITY",
        "household_weight__2022",
        "salary__2022-01",
    )
    for column in expected_columns:
        assert column in dataframe.columns
    assert "pseudo_input_for_safe_export__2022-01" not in dataframe.columns

    # The dict export and the property agree with the DataFrame export.
    assert "salary" in exported_data
    assert "pseudo_input_for_safe_export" not in exported_data
    assert "salary" in branch.true_input_variables
Loading