Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/fix-assert-near-float32.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Compare `assert_near` operands at float32 when `value` (the computed operand) is already float32, otherwise at float64. Keeps the H6 fix that catches dollar-level differences on multi-million-dollar float64 values, without surfacing float32 storage rounding (e.g. `8.91` stored as float32 vs the Python literal `8.91`) as spurious test failures in YAML tests against float-typed Variables.
1 change: 1 addition & 0 deletions changelog.d/fix-fast-cache-guards.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Guard all `_fast_cache` mutation sites in `Simulation` against the attribute being missing. `Simulation.__init__` initialises `self._fast_cache = {}`, but country-package subclasses (e.g. `policyengine_uk.Simulation`) legitimately override `__init__` without calling `super().__init__`, so `set_input`, `delete_arrays`, and `purge_cache_of_invalid_values` were raising `AttributeError` on those subclasses.
19 changes: 13 additions & 6 deletions policyengine_core/simulations/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,10 +825,12 @@ def purge_cache_of_invalid_values(self) -> None:
# We wait for the end of calculate(), signalled by an empty stack, before purging the cache
if self.tracer.stack:
return
_fast_cache = getattr(self, "_fast_cache", None)
for _name, _period in self.invalidated_caches:
holder = self.get_holder(_name)
holder.delete_arrays(_period)
self._fast_cache.pop((_name, _period), None)
if _fast_cache is not None:
_fast_cache.pop((_name, _period), None)
self.invalidated_caches = set()

def calculate_add(
Expand Down Expand Up @@ -1189,14 +1191,17 @@ def delete_arrays(self, variable: str, period: Period = None) -> None:
True
"""
self.get_holder(variable).delete_arrays(period)
_fast_cache = getattr(self, "_fast_cache", None)
if period is None:
self._fast_cache = {
k: v for k, v in self._fast_cache.items() if k[0] != variable
}
if _fast_cache is not None:
self._fast_cache = {
k: v for k, v in _fast_cache.items() if k[0] != variable
}
else:
if not isinstance(period, Period):
period = periods.period(period)
self._fast_cache.pop((variable, period), None)
if _fast_cache is not None:
_fast_cache.pop((variable, period), None)

def get_known_periods(self, variable: str) -> List[Period]:
"""
Expand Down Expand Up @@ -1241,7 +1246,9 @@ def set_input(self, variable_name: str, period: Period, value: ArrayLike) -> Non
if (variable.end is not None) and (period.start.date > variable.end):
return
self.get_holder(variable_name).set_input(period, value, self.branch_name)
self._fast_cache.pop((variable_name, period), None)
_fast_cache = getattr(self, "_fast_cache", None)
if _fast_cache is not None:
_fast_cache.pop((variable_name, period), None)

def get_variable_population(self, variable_name: str) -> Population:
variable = self.tax_benefit_system.get_variable(
Expand Down
25 changes: 18 additions & 7 deletions policyengine_core/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,24 @@ def assert_near(
if isinstance(target_value, str):
target_value = eval_expression(target_value)

# Use float64 here so we don't silently lose precision on values
# above ~16M (float32 only carries ~7 decimal digits). Under float32,
# ``25_000_001`` and ``25_000_000`` round to the same number and a
# test expecting one would pass on the other (bug H6).
target_value = np.array(target_value).astype(np.float64)

value = np.array(value).astype(np.float64)
# Choose comparison dtype:
# - Default to float64 so we don't silently lose precision on values
# above ~16M (float32 only carries ~7 decimal digits). Under
# float32, ``25_000_001`` and ``25_000_000`` round to the same
# number and a test expecting one would pass on the other (bug H6).
# - But if ``value`` is already float32 (because it came out of a
# float-typed Variable, which PolicyEngine stores as float32),
# promoting to float64 would surface the float32 rounding that's
# baked into storage — ``8.91`` stored as float32 compares unequal
# to the Python-literal ``8.91``. That rounding is not a regression
# surfaced by the H6 fix; it's a property of the storage dtype.
# Compare at float32 in that case to keep H6's coverage for real
# precision bugs (float64/int operands) without surfacing
# pre-existing float32 storage artefacts.
_value_array = np.asarray(value)
_compare_dtype = np.float32 if _value_array.dtype == np.float32 else np.float64
target_value = np.array(target_value).astype(_compare_dtype)
value = _value_array.astype(_compare_dtype)
diff = abs(target_value - value)
if absolute_error_margin is not None:
assert (diff <= absolute_error_margin).all(), (
Expand Down
25 changes: 19 additions & 6 deletions policyengine_core/tools/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,12 +488,25 @@ def assert_near(
target_value = eval_expression(target_value)

try:
# Use float64 here so we don't silently lose precision on values
# above ~16M (float32 only carries ~7 decimal digits). Under float32,
# ``25_000_001`` and ``25_000_000`` round to the same number and a
# test expecting one would pass on the other (bug H6).
target_value = np.array(target_value).astype(np.float64)
value = np.array(value).astype(np.float64)
# Choose comparison dtype:
# - Default to float64 so we don't silently lose precision on
# values above ~16M (float32 only carries ~7 decimal digits).
# Under float32, ``25_000_001`` and ``25_000_000`` round to the
# same number and a test expecting one would pass on the other
# (bug H6).
# - But if ``value`` is already float32 (because it came out of a
# float-typed Variable, which PolicyEngine stores as float32),
# promoting to float64 would surface the float32 rounding that's
# baked into storage — e.g. ``8.91`` stored as float32 compares
# unequal to the Python-literal ``8.91``. That rounding is not a
# regression; it's a property of the storage dtype. Compare at
# float32 in that case so we keep the H6 fix for real precision
# bugs (float64/int operands) without surfacing pre-existing
# float32 storage artefacts.
_value_array = np.asarray(value)
_compare_dtype = np.float32 if _value_array.dtype == np.float32 else np.float64
target_value = np.array(target_value).astype(_compare_dtype)
value = _value_array.astype(_compare_dtype)
except ValueError:
# Data type not translatable to floating point, assert complete equality
assert np.array(value) == np.array(target_value), "{}{} differs from {}".format(
Expand Down
18 changes: 18 additions & 0 deletions tests/core/test_assert_near_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from __future__ import annotations

import numpy as np
import pytest

from policyengine_core.tools import assert_near
Expand All @@ -25,3 +26,20 @@ def test_assert_near_detects_large_integer_difference():
def test_assert_near_still_passes_within_margin():
    """A difference smaller than the absolute margin must still be accepted."""
    computed, expected = 1.0, 1.0005
    # Comparing at float64 precision must not break legitimate
    # near-equality checks.
    assert_near(computed, expected, absolute_error_margin=1e-2)


def test_assert_near_accepts_float32_storage_rounding():
    """Float32-stored values must match their Python literals at zero margin.

    PolicyEngine stores float Variables as float32, and a literal such as
    8.91 is not exactly representable there (it rounds to about
    8.90999985). When a simulation hands back that rounded value and a YAML
    test expects the literal 8.91 with ``absolute_error_margin=0``, a
    float64 comparison would report the storage rounding as a failure even
    though the calculation is correct. ``assert_near`` is therefore expected
    to compare at float32 when the ``value`` operand is already float32,
    while int/float64 operands are still compared at float64.
    """
    literals = [8.91, 12.21, 7.3]

    # Scalar case: a single float32 value against its Python literal.
    stored_scalar = np.float32(literals[0])
    assert_near(stored_scalar, literals[0], absolute_error_margin=0)

    # Array case: a float32 array against a plain list of literals.
    stored_array = np.array(literals, dtype=np.float32)
    assert_near(stored_array, literals, absolute_error_margin=0)
76 changes: 76 additions & 0 deletions tests/core/test_fast_cache_guards.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""Regression tests: cache-manipulation methods on ``Simulation`` must not
crash when ``_fast_cache`` is missing.

``Simulation.__init__`` sets ``self._fast_cache = {}`` as the first step of
initialisation. Country-package subclasses (e.g. ``policyengine_uk.Simulation``)
can legitimately override ``__init__`` without calling ``super().__init__``
— instead they set the handful of attributes they need directly. In that
case ``_fast_cache`` never gets initialised, so any cache-mutation path
that assumed the attribute existed raised ``AttributeError`` during
``build_from_single_year_dataset`` / ``set_input`` / ``delete_arrays``.

The defensive fix is to guard the bare ``.pop`` / ``.items`` / re-assign
sites the same way the read-side fast path in ``calculate()`` already does
— ``getattr(self, "_fast_cache", None)`` and skip the cache write when
it's ``None``. Core owns this protection so every downstream subclass
doesn't have to mirror the attribute.
"""

from __future__ import annotations

import types

import numpy as np
import pytest

from policyengine_core.simulations import Simulation


def _bare_simulation():
    """Build a ``Simulation`` object without running ``__init__``.

    Allocating through ``__new__`` leaves the instance with none of the
    attributes ``__init__`` would set (``_fast_cache`` included), which
    reproduces what a country subclass sees when its own ``__init__`` does
    not initialise ``_fast_cache``.
    """
    uninitialised = Simulation.__new__(Simulation)
    return uninitialised


def test_set_input_without_fast_cache_attribute():
    """``set_input`` must not crash on an instance lacking ``_fast_cache``."""

    # Minimal stand-ins for the collaborators ``set_input`` is expected to
    # touch. Their behaviour is irrelevant here; they only need to let the
    # call reach the cache-pop step.
    def fake_get_variable(name, check_existence=True):
        return types.SimpleNamespace(end=None)

    def fake_get_holder(name):
        def ignore_input(period, value, branch):
            return None

        return types.SimpleNamespace(set_input=ignore_input)

    sim = _bare_simulation()
    sim.start_instant = None
    sim.branch_name = "default"
    sim.tax_benefit_system = types.SimpleNamespace(get_variable=fake_get_variable)
    sim.get_holder = fake_get_holder

    # The regression check itself: this call used to raise
    # ``AttributeError: 'Simulation' object has no attribute '_fast_cache'``.
    sim.set_input("variable_name", "2024", [1, 2, 3])


def test_delete_arrays_without_fast_cache_attribute():
    """``delete_arrays`` must not crash on an instance lacking ``_fast_cache``.

    Both branches are exercised: ``period=None`` (drop every period of the
    variable) and an explicit period.
    """

    def fake_get_holder(name):
        def ignore_delete(period):
            return None

        return types.SimpleNamespace(delete_arrays=ignore_delete)

    sim = _bare_simulation()
    sim.get_holder = fake_get_holder

    # No ``_fast_cache`` attribute exists on ``sim``; neither call may raise.
    for requested_period in (None, "2024"):
        sim.delete_arrays("variable", period=requested_period)


def test_purge_cache_of_invalid_values_without_fast_cache_attribute():
    """Purging must succeed, and clear the pending set, without ``_fast_cache``."""

    def fake_get_holder(name):
        def ignore_delete(period):
            return None

        return types.SimpleNamespace(delete_arrays=ignore_delete)

    sim = _bare_simulation()
    # An empty tracer stack signals that no calculation is in progress, so
    # the purge actually runs instead of returning early.
    sim.tracer = types.SimpleNamespace(stack=[])
    pending = {("variable_name", "2024")}
    sim.invalidated_caches = pending
    sim.get_holder = fake_get_holder

    sim.purge_cache_of_invalid_values()

    assert sim.invalidated_caches == set()
Loading
Loading