From 04209701ed850b0ba49537a54eaefea83c733acf Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Sun, 17 May 2026 10:05:35 -0400 Subject: [PATCH] Enforce exclusive variable computation modes --- ...llow-uprating-formula-variables.changed.md | 1 + policyengine_core/commons/formulas.py | 36 ++-- policyengine_core/variables/variable.py | 37 ++++ tests/core/test_medium_fixes.py | 21 -- tests/core/variables/test_variables.py | 196 ++++++++++++++++-- 5 files changed, 240 insertions(+), 51 deletions(-) create mode 100644 changelog.d/disallow-uprating-formula-variables.changed.md diff --git a/changelog.d/disallow-uprating-formula-variables.changed.md b/changelog.d/disallow-uprating-formula-variables.changed.md new file mode 100644 index 00000000..c8173427 --- /dev/null +++ b/changelog.d/disallow-uprating-formula-variables.changed.md @@ -0,0 +1 @@ +Disallow variable definitions that combine formula, adds/subtracts, and uprating computation modes. diff --git a/policyengine_core/commons/formulas.py b/policyengine_core/commons/formulas.py index 0017f515..21c80106 100644 --- a/policyengine_core/commons/formulas.py +++ b/policyengine_core/commons/formulas.py @@ -348,10 +348,28 @@ def uprated(by: str = None, start_year: int = 2015) -> Callable: """ def uprater(variable: Type[Variable]) -> type: - if hasattr(variable, f"formula_{start_year}"): - return variable - - formula = variable.formula if hasattr(variable, "formula") else None + formula_names = [ + name for name in variable.__dict__ if name.startswith("formula") + ] + if formula_names: + raise ValueError( + f'Variable "{variable.__name__}" uses @uprated and has a formula. ' + "Uprating is only supported for input variables; formulas " + "should handle their own time behavior explicitly." + ) + if "adds" in variable.__dict__ or "subtracts" in variable.__dict__: + raise ValueError( + f'Variable "{variable.__name__}" uses @uprated and has ' + "adds/subtracts. Uprating is only supported for input " + "variables without formula, adds/subtracts, or uprating " + "metadata." + ) + if "uprating" in variable.__dict__: + raise ValueError( + f'Variable "{variable.__name__}" uses @uprated and has ' + "uprating. Uprating is only supported for input variables " + "without formula, adds/subtracts, or uprating metadata." + ) variable.metadata = { "uprating": by, @@ -368,16 +386,6 @@ def formula_start_year(entity, period, parameters): last_year_parameter = getattr(last_year_parameter, name) uprating = current_parameter / last_year_parameter old = entity(variable.__name__, period.last_year) - # Use numpy.all on the element-wise equality with 0; Python's - # ``all(old)`` checks truthiness of each element, so a single - # non-zero value makes the guard ``False`` even when every - # other value is zero — which defeated the "no values were - # inputted" short-circuit and caused uprating to run on top - # of a formula fall-back output (bug M1). - if (formula is not None) and np.all(old == 0): - # If no values have been inputted, don't uprate and - # instead use the previous formula on the current period. - return formula(entity, period, parameters) return uprating * old formula_start_year.__name__ = f"formula_{start_year}" diff --git a/policyengine_core/variables/variable.py b/policyengine_core/variables/variable.py index dda76e0b..97944ad8 100644 --- a/policyengine_core/variables/variable.py +++ b/policyengine_core/variables/variable.py @@ -318,6 +318,8 @@ def __init__(self, baseline_variable=None): ) self.formulas = self.set_formulas(formulas_attr) + self.check_computation_modes() + if unexpected_attrs: raise ValueError( 'Unexpected attributes in definition of variable "{}": {!r}'.format( @@ -329,6 +331,41 @@ def __init__(self, baseline_variable=None): # ----- Setters used to build the variable ----- # + @property + def uprating(self): + return getattr(self, "_uprating", None) + + @uprating.setter + def uprating(self, value): + old_value = getattr(self, "_uprating", None) + self._uprating = value + if hasattr(self, "formulas"): + try: + self.check_computation_modes() + except ValueError: + self._uprating = old_value + raise + + def get_computation_modes(self): + computation_modes = [] + if self.formulas: + computation_modes.append("formula") + if self.adds is not None or self.subtracts is not None: + computation_modes.append("adds/subtracts") + if self.uprating is not None: + computation_modes.append("uprating") + return computation_modes + + def check_computation_modes(self): + computation_modes = self.get_computation_modes() + if len(computation_modes) > 1: + raise ValueError( + f'Variable "{self.name}" mixes computation modes: ' + f"{' and '.join(computation_modes)}. Variables must use at " + "most one of formula, adds/subtracts, or uprating; plain " + "input or constant variables should use none." + ) + def set( self, attributes, diff --git a/tests/core/test_medium_fixes.py b/tests/core/test_medium_fixes.py index 7457e771..cb00f60e 100644 --- a/tests/core/test_medium_fixes.py +++ b/tests/core/test_medium_fixes.py @@ -1,9 +1,5 @@ """Regression tests for a batch of surgical Medium-severity fixes. -* M1 — ``@uprated`` short-circuit used Python ``all()`` instead of - ``numpy.all(old == 0)``. Previously the guard returned ``True`` only - when the first element was zero (truthiness) rather than "no values - have been inputted". * M8 — ``SimulationBuilder`` multi-axis ``linspace`` branch divided by ``axis_count - 1``, crashing on single-point axes. * M10 — ``Dataset.download`` parsed ``release://org/repo/tag/file`` with @@ -17,7 +13,6 @@ import datetime -import numpy as np import pytest from policyengine_core.variables.config import VALUE_TYPES @@ -42,19 +37,3 @@ def test_single_point_axis_does_not_divide_by_zero(persons): # After the fix, a single-point axis produces the ``axis["min"]`` value. builder.expand_axes() assert builder.get_input("salary", "2018-11") == pytest.approx([500]) - - -def test_all_numpy_guard_triggers_on_all_zero_old(): - """Bug M1: ``np.all(old == 0)`` must be used, not Python ``all(old)``. - - Python ``all([1, 0, 0])`` == True (because 1 is truthy), so the guard - would NOT fire. ``np.all([1, 0, 0] == 0)`` == False, which correctly - says "not all zero". - """ - old = np.array([1, 0, 0]) - # Python truthy-check semantics: ``all([1, 0, 0])`` -> False because - # 0 is falsy. For the reversed test case with all zeros: - all_zero = np.array([0, 0, 0]) - # The fix uses ``np.all(old == 0)`` which is True iff every element is 0. - assert np.all(all_zero == 0) - assert not np.all(old == 0) diff --git a/tests/core/variables/test_variables.py b/tests/core/variables/test_variables.py index d98a5232..de8c0ec1 100644 --- a/tests/core/variables/test_variables.py +++ b/tests/core/variables/test_variables.py @@ -7,7 +7,7 @@ import policyengine_core.country_template as country_template import policyengine_core.country_template.situation_examples from policyengine_core.country_template.entities import Person -from policyengine_core.model_api import Variable +from policyengine_core.model_api import Variable, uprated from policyengine_core.periods import ETERNITY, MONTH from policyengine_core.simulations import SimulationBuilder from policyengine_core.tools import assert_near @@ -555,13 +555,11 @@ def formula(): def test_one_formula_one_add(): - check_error_at_add_variable( - tax_benefit_system, - variable__one_formula_one_add, - 'Variable "{name}" has a formula and an add or subtract'.format( - name="variable__one_formula_one_add" - ), - ) + with raises( + ValueError, + match='Variable "variable__one_formula_one_add" mixes computation modes: formula and adds/subtracts', + ): + tax_benefit_system.add_variable(variable__one_formula_one_add) class variable__one_formula_one_subtract(Variable): @@ -569,20 +567,144 @@ class variable__one_formula_one_subtract(Variable): entity = Person definition_period = MONTH label = "Variable with one formula and one subtract." - adds = ["pass"] + subtracts = ["pass"] def formula(): pass def test_one_formula_one_subtract(): - check_error_at_add_variable( - tax_benefit_system, - variable__one_formula_one_subtract, - 'Variable "{name}" has a formula and an add or subtract'.format( - name="variable__one_formula_one_subtract" - ), - ) + with raises( + ValueError, + match='Variable "variable__one_formula_one_subtract" mixes computation modes: formula and adds/subtracts', + ): + tax_benefit_system.add_variable(variable__one_formula_one_subtract) + + +class variable__one_formula_one_uprating(Variable): + value_type = int + entity = Person + definition_period = MONTH + label = "Variable with one formula and one uprating." + uprating = "uprating.index" + + def formula(): + pass + + +def test_one_formula_one_uprating(): + with raises( + ValueError, + match='Variable "variable__one_formula_one_uprating" mixes computation modes: formula and uprating', + ): + tax_benefit_system.add_variable(variable__one_formula_one_uprating) + + +class variable__one_add_one_uprating(Variable): + value_type = int + entity = Person + definition_period = MONTH + label = "Variable with one add and one uprating." + adds = ["pass"] + uprating = "uprating.index" + + +def test_one_add_one_uprating(): + with raises( + ValueError, + match='Variable "variable__one_add_one_uprating" mixes computation modes: adds/subtracts and uprating', + ): + tax_benefit_system.add_variable(variable__one_add_one_uprating) + + +class variable__one_subtract_one_uprating(Variable): + value_type = int + entity = Person + definition_period = MONTH + label = "Variable with one subtract and one uprating." + subtracts = ["pass"] + uprating = "uprating.index" + + +def test_one_subtract_one_uprating(): + with raises( + ValueError, + match='Variable "variable__one_subtract_one_uprating" mixes computation modes: adds/subtracts and uprating', + ): + tax_benefit_system.add_variable(variable__one_subtract_one_uprating) + + +def test_uprated_decorator_rejects_existing_formula(): + with raises( + ValueError, + match='Variable "variable__uprated_decorator_one_formula" uses @uprated and has a formula', + ): + + @uprated("uprating.index") + class variable__uprated_decorator_one_formula(Variable): + value_type = int + entity = Person + definition_period = MONTH + label = "Variable with @uprated and one formula." + + def formula(): + pass + + +def test_uprated_decorator_rejects_existing_adds(): + with raises( + ValueError, + match='Variable "variable__uprated_decorator_one_add" uses @uprated and has adds/subtracts', + ): + + @uprated("uprating.index") + class variable__uprated_decorator_one_add(Variable): + value_type = int + entity = Person + definition_period = MONTH + label = "Variable with @uprated and one add." + adds = ["pass"] + + +def test_uprated_decorator_rejects_existing_subtracts(): + with raises( + ValueError, + match='Variable "variable__uprated_decorator_one_subtract" uses @uprated and has adds/subtracts', + ): + + @uprated("uprating.index") + class variable__uprated_decorator_one_subtract(Variable): + value_type = int + entity = Person + definition_period = MONTH + label = "Variable with @uprated and one subtract." + subtracts = ["pass"] + + +def test_uprated_decorator_rejects_existing_uprating(): + with raises( + ValueError, + match='Variable "variable__uprated_decorator_one_uprating" uses @uprated and has uprating', + ): + + @uprated("uprating.index") + class variable__uprated_decorator_one_uprating(Variable): + value_type = int + entity = Person + definition_period = MONTH + label = "Variable with @uprated and one uprating." + uprating = "uprating.index" + + +def test_uprated_decorator_allows_input_variable(): + @uprated("uprating.index") + class variable__uprated_decorator_input(Variable): + value_type = int + entity = Person + definition_period = MONTH + label = "Input variable with @uprated." + + assert hasattr(variable__uprated_decorator_input, "formula_2015") class variable__one_formula(Variable): @@ -629,6 +751,48 @@ def test_one_subtract(): assert len(variable.subtracts) +class variable__one_add_one_subtract(Variable): + value_type = int + entity = Person + definition_period = MONTH + label = "Variable with one add and one subtract." + adds = ["pass"] + subtracts = ["pass"] + + +def test_one_add_one_subtract(): + tax_benefit_system.add_variable(variable__one_add_one_subtract) + variable = tax_benefit_system.variables["variable__one_add_one_subtract"] + assert len(variable.adds) + assert len(variable.subtracts) + + +def test_runtime_uprating_assignment_rejects_existing_adds(): + variable = variable__one_add() + + with raises( + ValueError, + match='Variable "variable__one_add" mixes computation modes: adds/subtracts and uprating', + ): + variable.uprating = "uprating.index" + assert variable.uprating is None + + +class variable__runtime_uprating_input(Variable): + value_type = int + entity = Person + definition_period = MONTH + label = "Input variable with runtime uprating assignment." + + +def test_runtime_uprating_assignment_allows_input_variable(): + variable = variable__runtime_uprating_input() + + variable.uprating = "uprating.index" + + assert variable.uprating == "uprating.index" + + class variable__no_label(Variable): value_type = int entity = Person