Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/fix-invalidate-preserves-user-inputs.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Preserve `set_input` values across `apply_reform`. The H3 cache invalidation wiped every variable's `_memory_storage._arrays`, which also wiped user-provided dataset inputs loaded via `set_input`. Country-package subclasses calling `set_input` during construction and then applying a structural reform (the `policyengine-uk` pattern) silently lost their datasets. Now tracks `set_input` provenance and replays those values after the invalidation wipe; formula-output caches are still invalidated as before.
1 change: 1 addition & 0 deletions changelog.d/fix-nested-branch-input-inheritance.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Restore nested-branch input inheritance and cover situation-dict `set_input`. Three follow-ups on top of the `Simulation.set_input` preservation: (1) `Holder.set_input` also records each input in `_user_input_keys`, so situation-dict inputs routed through `SimulationBuilder.finalize_variables_init` survive `apply_reform`, not only inputs set via `Simulation.set_input`; (2) `Holder.get_array` walks up `simulation.parent_branch` before falling back to `default`, so a sub-branch (e.g. `no_salt` cloned from `itemizing`) still sees inputs set on its ancestor — the post-C1 behavior of falling back to `default` only broke the country-package nested-branch pattern; (3) `GroupPopulation.clone` now passes the cloned population (not the source) to `holder.clone`, so group-entity holders on a `get_branch` clone point at the cloned simulation and branch-aware lookups resolve correctly. Unblocks `PolicyEngine/policyengine-us#8066` (the `tax_unit_itemizes` integration test crashing with `TypeError: int() argument ... not 'NoneType'` under core 3.24.x because `state_fips` got wiped, plus a follow-up infinite recursion in `tax_liability_if_itemizing` once the `state_fips` wipe was resolved).
38 changes: 33 additions & 5 deletions policyengine_core/holders/holder.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,27 @@ def get_array(self, period: Period, branch_name: str = "default") -> ArrayLike:
return self.default_array()
value = self._memory_storage.get(period, branch_name)
if value is None and branch_name != "default":
# Fall back to the ``default`` branch only. Previously the fallback
# returned *any* branch that happened to have this period (the
# first one in dict-insertion order), which silently swapped
# values between unrelated branches (reform vs baseline) and
# produced wrong reform deltas. See holder.get_array bug C1.
# Walk up ``simulation.parent_branch`` so nested branches inherit
# values from their parent (e.g. a ``no_salt`` branch cloned
# from an ``itemizing`` branch still sees ``tax_unit_itemizes``
# set on the ``itemizing`` branch). Fall back to ``default``
# only if no ancestor branch has a value. Previously the
# fallback returned the first branch in dict-insertion order
# (bug C1) — silently swapping values between unrelated
# sibling branches (reform vs baseline) and producing wrong
# reform deltas. The post-C1 behavior only fell back to
# ``default``, which broke country-package nested-branch
# patterns that relied on the ancestor's input being visible.
parent = (
getattr(self.simulation, "parent_branch", None)
if self.simulation
else None
)
while parent is not None:
ancestor_value = self._memory_storage.get(period, parent.branch_name)
if ancestor_value is not None:
return ancestor_value
parent = getattr(parent, "parent_branch", None)
default_value = self._memory_storage.get(period, "default")
if default_value is not None:
return default_value
Expand Down Expand Up @@ -225,6 +241,18 @@ def set_input(
return warnings.warn(warning_message, Warning)
if self.variable.value_type in (float, int) and isinstance(array, str):
array = tools.eval_expression(array)
# Track user-provided inputs on the simulation so
# ``Simulation._invalidate_all_caches`` can preserve them across
# ``apply_reform``. ``Simulation.set_input`` also records this, but
# ``SimulationBuilder.finalize_variables_init`` (the situation-dict
# path) and country-package dataset loaders call
# ``holder.set_input`` directly, bypassing the simulation-level hook.
# Recording here covers both paths.
simulation = getattr(self, "simulation", None)
if simulation is not None:
if not hasattr(simulation, "_user_input_keys"):
simulation._user_input_keys = set()
simulation._user_input_keys.add((self.variable.name, branch_name, period))
if self.variable.set_input and period.unit != self.variable.definition_period:
return self.variable.set_input(self, period, array)
return self._set(period, array, branch_name)
Expand Down
10 changes: 9 additions & 1 deletion policyengine_core/populations/group_population.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,16 @@ def __call__(
def clone(self, simulation: "Simulation", members: Population) -> "GroupPopulation":
result = GroupPopulation(self.entity, members)
result.simulation = simulation
# Pass ``result`` (the cloned population) to ``holder.clone`` so the
# holder's ``.simulation`` reference points at the clone — not at
# the source. Previously this was ``holder.clone(self)``, which
# left every group-entity holder on a ``get_branch`` clone
# pointing back at its parent simulation; that broke ``branch_name``
# and ``parent_branch`` lookups for group-entity variables
# (e.g. ``tax_unit_itemizes``) on nested branches.
result._holders = {
variable: holder.clone(self) for (variable, holder) in self._holders.items()
variable: holder.clone(result)
for (variable, holder) in self._holders.items()
}
result.count = self.count
result.ids = self.ids
Expand Down
45 changes: 39 additions & 6 deletions policyengine_core/simulations/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,12 @@ def __init__(

self.invalidated_caches = set()
self._fast_cache: dict = {}
# ``set_input`` records each (variable_name, branch_name, period) it
# populates so ``_invalidate_all_caches`` can tell user-provided
# source data apart from formula-computed caches. Without this the
# post-``apply_reform`` cache wipe would also wipe the dataset the
# simulation was loaded from.
self._user_input_keys: set[tuple[str, str, Period]] = set()
self.debug: bool = False
self.trace: bool = trace
self.tracer: SimpleTracer = SimpleTracer() if not trace else FullTracer()
Expand Down Expand Up @@ -251,23 +257,44 @@ def apply_reform(self, reform: Union[tuple, Reform]):
self._invalidate_all_caches()

def _invalidate_all_caches(self) -> None:
"""Purge every cached calculation on this simulation.
"""Purge cached formula output, preserving user-provided inputs.

Called after ``apply_reform`` and any other operation that changes
the tax-benefit system underneath an already-calculated simulation.
Also cascades into any branches created via ``get_branch`` so those
don't keep returning stale pre-reform values either.

Every (variable, branch, period) that was populated via
``set_input`` is preserved — those are source data, not stale
formula output — so a structural reform applied after dataset
load doesn't silently discard the dataset. Everything else
(formula outputs, cached short-path results, on-disk caches) is
wiped so the next ``calculate`` recomputes under the new
tax-benefit system.
"""
self._fast_cache = {}
self.invalidated_caches = set()
# Snapshot user-provided inputs before wiping so they can be
# replayed into the fresh storage. Storage keys each entry as
# f"{branch_name}:{period}"; preserve exactly those keys.
preserved: dict[str, dict[str, object]] = {}
user_input_keys = getattr(self, "_user_input_keys", None) or set()
for variable_name, branch_name, period in user_input_keys:
holder = self.get_holder(variable_name)
storage_key = f"{branch_name}:{period}"
stored_value = holder._memory_storage._arrays.get(storage_key)
if stored_value is not None:
preserved.setdefault(variable_name, {})[storage_key] = stored_value
for variable in list(self.tax_benefit_system.variables):
holder = self.get_holder(variable)
# ``Holder.delete_arrays`` with ``period=None`` wipes every
# period on both memory and disk storage. After the storage-delete
# bug fix (C2) that now respects branch_name, so wipe both.
# Wipe formula outputs and on-disk caches from both memory and
# disk storage. The storage-delete bug fix (C2) respects
# branch_name, so wipe both stores here.
holder._memory_storage._arrays = {}
if holder._disk_storage is not None:
holder._disk_storage._files = {}
# Replay preserved user inputs so ``calculate`` still sees them.
for variable_name, key_to_array in preserved.items():
holder = self.get_holder(variable_name)
holder._memory_storage._arrays.update(key_to_array)
for branch in self.branches.values():
branch._invalidate_all_caches()

Expand Down Expand Up @@ -1246,6 +1273,12 @@ def set_input(self, variable_name: str, period: Period, value: ArrayLike) -> Non
if (variable.end is not None) and (period.start.date > variable.end):
return
self.get_holder(variable_name).set_input(period, value, self.branch_name)
# Lazy-init ``_user_input_keys`` so country-package subclasses that
# override ``__init__`` without calling ``super().__init__`` still
# benefit from the set-input preservation across ``apply_reform``.
if not hasattr(self, "_user_input_keys"):
self._user_input_keys = set()
self._user_input_keys.add((variable_name, self.branch_name, period))
_fast_cache = getattr(self, "_fast_cache", None)
if _fast_cache is not None:
_fast_cache.pop((variable_name, period), None)
Expand Down
170 changes: 170 additions & 0 deletions tests/core/test_apply_reform_preserves_user_inputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
"""Regression test: ``apply_reform`` must not wipe user-provided inputs.

The cache-invalidation added to fix bug H3 cleared
``holder._memory_storage._arrays`` for every variable, which also wiped
values populated via ``set_input`` that pre-dated the reform. Those values
are source data, not stale formula output, and must survive
``apply_reform`` so that downstream country packages (e.g. policyengine-uk,
policyengine-us) can load a dataset, apply a structural reform, and then
calculate against the loaded data.

See:
- https://github.com/PolicyEngine/policyengine.py/issues/1628 (symptom
surfaced here — UK household-impact tests returning 0 after reform apply)
- https://github.com/PolicyEngine/policyengine-us/issues/8058 (symptom
surfaced here — US ``tax_unit_itemizes`` integration test and many
others crashing with ``TypeError: int() argument must be a string,
a bytes-like object or a real number, not 'NoneType'`` because
``state_fips`` got wiped and the downstream
``state_name``/``state_code`` chain returned ``None``)
- bug H3 in the existing ``test_apply_reform_invalidates_cache.py``
(the cache invalidation that over-reached)
"""

from __future__ import annotations

import numpy as np

from policyengine_core.model_api import Reform
from policyengine_core.country_template import situation_examples
from policyengine_core.simulations import Simulation, SimulationBuilder


def test_apply_reform_preserves_set_input_values(tax_benefit_system):
    """Check that a ``set_input`` value is still present after ``apply_reform``.

    The test feeds a salary into the simulation with ``set_input``, applies
    a reform whose ``apply`` does nothing, and then requires both that the
    salary holder still reports a known period and that ``calculate``
    returns the value that was originally supplied.
    """

    class NoOpReform(Reform):
        """Reform that touches nothing; should not invalidate inputs."""

        def apply(self):
            pass

    month = "2017-01"
    salary_in = np.array([5_000.0])
    simulation = SimulationBuilder().build_from_entities(
        tax_benefit_system, situation_examples.single
    )
    simulation.set_input("salary", month, salary_in)

    # Guard: the input has to be registered before the reform, otherwise
    # the post-reform checks below would prove nothing.
    known_before = simulation.get_holder("salary").get_known_periods()
    assert known_before, "precondition failure: set_input did not register the period"

    simulation.apply_reform(NoOpReform)

    # The holder must still know about the period that was set...
    known_after = simulation.get_holder("salary").get_known_periods()
    assert known_after, (
        "apply_reform wiped salary holder — set_input values must be preserved"
    )

    # ...and calculating the variable must give back the supplied value.
    salary_out = simulation.calculate("salary", period=month)
    assert np.allclose(salary_out, salary_in), (
        f"apply_reform lost the user-provided salary input; got {salary_out} "
        f"instead of {salary_in}."
    )


def test_apply_reform_preserves_inputs_across_multiple_variables(tax_benefit_system):
    """Check that several ``set_input`` variables all survive ``apply_reform``.

    Salary and age are both set for the same period, a do-nothing reform is
    applied, and each variable must still calculate to its supplied value.
    """

    class NoOpReform(Reform):
        def apply(self):
            pass

    month = "2017-01"
    simulation = SimulationBuilder().build_from_entities(
        tax_benefit_system, situation_examples.single
    )

    # Populate the inputs in the same order as they are checked below.
    supplied = (
        ("salary", np.array([1_234.0])),
        ("age", np.array([27])),
    )
    for variable_name, values in supplied:
        simulation.set_input(variable_name, month, values)

    simulation.apply_reform(NoOpReform)

    salary_after = simulation.calculate("salary", period=month)
    assert np.allclose(salary_after, [1_234.0])

    age_after = simulation.calculate("age", period=month)
    assert age_after[0] == 27


def test_apply_reform_preserves_situation_dict_inputs(tax_benefit_system):
    """Check that inputs given through a situation dict survive ``apply_reform``.

    The simulation is constructed with ``Simulation(situation=...)`` instead
    of calling ``Simulation.set_input``. After a do-nothing reform is
    applied, the salary and age declared in the situation must still be
    returned by ``calculate``.
    """

    class NoOpReform(Reform):
        def apply(self):
            pass

    person_inputs = {
        "salary": {"2017-01": 3_000.0},
        "age": {"2017-01": 42},
    }
    situation = {
        "persons": {"Alicia": person_inputs},
        "households": {"_": {"parents": ["Alicia"]}},
    }
    simulation = Simulation(
        tax_benefit_system=tax_benefit_system,
        situation=situation,
    )

    # Guard: both inputs must be registered on their holders before the
    # reform is applied.
    for variable_name in ("salary", "age"):
        assert simulation.get_holder(variable_name).get_known_periods(), (
            f"precondition failure: situation dict did not register {variable_name}"
        )

    simulation.apply_reform(NoOpReform)

    # Neither value went through ``Simulation.set_input``; they must be
    # preserved all the same.
    salary_after = simulation.calculate("salary", period="2017-01")
    assert np.allclose(salary_after, [3_000.0]), (
        "apply_reform wiped the situation-dict salary input"
    )

    age_after = simulation.calculate("age", period="2017-01")
    assert age_after[0] == 42, "apply_reform wiped the situation-dict age input"


def test_apply_reform_still_invalidates_formula_caches(tax_benefit_system):
    """Check that formula results cached before a reform are recomputed.

    Preserving ``set_input`` values must not stop ``apply_reform`` from
    clearing cached formula output: ``basic_income`` is calculated once,
    then neutralized by a reform, and the next ``calculate`` call has to
    return the post-reform value (0) instead of the earlier positive one.
    """

    class NeutraliseBasicIncome(Reform):
        def apply(self):
            self.neutralize_variable("basic_income")

    month = "2017-01"
    simulation = SimulationBuilder().build_from_entities(
        tax_benefit_system, situation_examples.single
    )

    # First calculation: fills the cache with a pre-reform result.
    pre_reform_income = simulation.calculate("basic_income", period=month)
    assert pre_reform_income[0] > 0

    simulation.apply_reform(NeutraliseBasicIncome)

    # Second calculation: must reflect the neutralized variable.
    post_reform_income = simulation.calculate("basic_income", period=month)
    assert post_reform_income[0] == 0, (
        f"apply_reform did not invalidate formula cache for basic_income; "
        f"got {post_reform_income[0]} instead of 0."
    )
67 changes: 67 additions & 0 deletions tests/core/test_holder_branch_fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,70 @@ def test_get_array_falls_back_to_default_branch(tax_benefit_system):
result = holder.get_array(period, "baseline")
assert result is not None
assert result[0] == 42.0


def test_get_array_falls_back_through_parent_branch_chain(tax_benefit_system):
    """Check that a nested branch reads values stored for its parent branch.

    Mirrors the two-level branch pattern described for ``policyengine-us``:
    an ``itemizing`` branch is created from the base simulation and a
    ``no_salt`` branch is created from ``itemizing``. A value stored under
    the ``itemizing`` branch name must be what ``get_array`` returns when
    it is asked for the ``no_salt`` branch; otherwise the nested branch
    would recompute the variable instead of inheriting its parent's input.
    """
    base_simulation = _build_single(tax_benefit_system)
    no_salt = base_simulation.get_branch("itemizing").get_branch("no_salt")

    month = periods.period("2017-01")
    salary_holder = no_salt.person.get_holder("salary")

    # Stand-in for an input set on the ``itemizing`` branch: write the
    # array into this holder's memory storage under that branch's name.
    salary_holder._memory_storage.put(np.asarray([7_777.0]), month, "itemizing")

    # Nothing is stored under ``no_salt`` itself, so the lookup has to climb
    # to the parent branch to find a value.
    inherited = salary_holder.get_array(month, "no_salt")
    assert inherited is not None, (
        "get_array on a nested branch must fall back through parent_branch "
        "to the ancestor that actually has the value"
    )
    assert inherited[0] == 7_777.0


def test_group_population_clone_sets_holder_simulation_to_clone(tax_benefit_system):
    """Check that group-entity holders on a branch reference that branch.

    After ``get_branch`` clones a simulation, a holder fetched from the
    branch's household population must have ``.simulation`` set to the
    branch itself, not to the simulation it was cloned from. Branch-aware
    lookups depend on that reference.
    """
    source_simulation = _build_single(tax_benefit_system)
    nested_branch = source_simulation.get_branch("nested")

    # ``housing_tax`` is looked up on the household (group) population.
    housing_tax_holder = nested_branch.household.get_holder("housing_tax")

    assert housing_tax_holder.simulation is nested_branch, (
        "GroupPopulation.clone must pass the CLONED population to "
        "holder.clone so holder.simulation points at the new branch, "
        "not the source simulation"
    )
    assert housing_tax_holder.simulation.branch_name == "nested"
Loading
Loading