Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions changelog.d/extra-variables-household-impact.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
`calculate_household_impact` (US and UK) and `Simulation` now accept an
`extra_variables` mapping (`{entity_name: [variable_name, ...]}`) so
callers can request variables beyond the bundled `entity_variables`
default set without monkey-patching. This unblocks benchmark suites
(e.g. `policybench`) that need variables such as
`adjusted_gross_income`, `state_agi`, `free_school_meals`, or
`is_medicaid_eligible` that the default list does not include. The
per-entity dicts on the returned `USHouseholdOutput` /
`UKHouseholdOutput` gain the requested keys; existing keys are unchanged.
9 changes: 9 additions & 0 deletions src/policyengine/core/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,15 @@ class Simulation(BaseModel):

tax_benefit_model_version: TaxBenefitModelVersion = None

extra_variables: dict[str, list[str]] = Field(
default_factory=dict,
description=(
"Additional variables to calculate beyond the model version's "
"default entity_variables, keyed by entity name. Use when a "
"caller needs variables that are not in the bundled default set."
),
)

@model_validator(mode="after")
def _auto_construct_strategy(self) -> "Simulation":
"""Auto-construct a RowFilterStrategy from legacy filter fields.
Expand Down
13 changes: 12 additions & 1 deletion src/policyengine/tax_benefit_models/uk/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,18 @@ def run(self, simulation: "Simulation") -> "Simulation":
"household": pd.DataFrame(),
}

for entity, variables in self.entity_variables.items():
combined: dict[str, list[str]] = {
entity: list(variables)
for entity, variables in self.entity_variables.items()
}
for entity, extras in (simulation.extra_variables or {}).items():
combined.setdefault(entity, [])
for var in extras:
if var not in combined[entity]:
combined[entity].append(var)
for entity, variables in combined.items():
if entity not in data:
data[entity] = pd.DataFrame()
for var in variables:
data[entity][var] = microsim.calculate(
var, period=simulation.dataset.year, map_to=entity
Expand Down
23 changes: 20 additions & 3 deletions src/policyengine/tax_benefit_models/us/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,15 @@ class USHouseholdInput(BaseModel):
def calculate_household_impact(
household_input: USHouseholdInput,
policy: Optional[Policy] = None,
extra_variables: Optional[dict[str, list[str]]] = None,
) -> USHouseholdOutput:
"""Calculate tax and benefit impacts for a single US household."""
"""Calculate tax and benefit impacts for a single US household.

``extra_variables`` is a mapping from entity name (``person``,
``tax_unit``, ``household``, etc.) to additional variable names to
compute beyond ``us_latest.entity_variables``. Useful for benchmark
suites that need variables outside the bundled default set.
"""
n_people = len(household_input.people)

# Build person data with defaults
Expand Down Expand Up @@ -148,6 +155,7 @@ def calculate_household_impact(
dataset=dataset,
tax_benefit_model_version=us_latest,
policy=policy,
extra_variables=extra_variables or {},
)
simulation.run()

Expand All @@ -161,13 +169,22 @@ def safe_convert(value):
except (ValueError, TypeError):
return str(value)

extras = extra_variables or {}

def variables_for(entity_name: str) -> list[str]:
    """Return the variable names to report for ``entity_name``.

    The result starts as a copy of the entry for ``entity_name`` in
    ``us_latest.entity_variables`` (empty when there is none) and is
    extended, in order, with each name from ``extras[entity_name]`` that
    is not already present. A fresh list is returned on every call, so
    the model's default list is never mutated.
    """
    names = [*us_latest.entity_variables.get(entity_name, [])]
    for candidate in extras.get(entity_name, []):
        if candidate in names:
            # Already present (either a default or an earlier extra).
            continue
        names.append(candidate)
    return names

def extract_entity_outputs(
    entity_name: str, entity_data, n_rows: int
) -> list[dict[str, Any]]:
    """Convert the first ``n_rows`` rows of ``entity_data`` into dicts.

    Args:
        entity_name: Entity whose variable list is looked up through
            ``variables_for``.
        entity_data: Column-indexable table for that entity; each
            ``entity_data[var]`` must support positional ``.iloc[i]``.
        n_rows: Number of leading rows to extract.

    Returns:
        One dict per row, mapping each variable name to the value at
        that row position after it has been passed through
        ``safe_convert``.
    """
    # The variable list does not depend on the row, so resolve it once
    # instead of rebuilding it on every iteration.
    variables = variables_for(entity_name)
    return [
        {var: safe_convert(entity_data[var].iloc[i]) for var in variables}
        for i in range(n_rows)
    ]
Expand All @@ -182,7 +199,7 @@ def extract_entity_outputs(
tax_unit=extract_entity_outputs("tax_unit", output_data.tax_unit, 1),
household={
var: safe_convert(output_data.household[var].iloc[0])
for var in us_latest.entity_variables["household"]
for var in variables_for("household")
},
)

Expand Down
17 changes: 15 additions & 2 deletions src/policyengine/tax_benefit_models/us/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,8 +403,21 @@ def run(self, simulation: "Simulation") -> "Simulation":
if target_col in id_columns:
data["person"][target_col] = person_input_df[col].values

# Then calculate non-ID, non-weight variables from simulation
for entity, variables in self.entity_variables.items():
# Then calculate non-ID, non-weight variables from simulation,
# merging the model version's default entity_variables with any
# extra variables requested on the Simulation.
combined: dict[str, list[str]] = {
entity: list(variables)
for entity, variables in self.entity_variables.items()
}
for entity, extras in (simulation.extra_variables or {}).items():
combined.setdefault(entity, [])
for var in extras:
if var not in combined[entity]:
combined[entity].append(var)
for entity, variables in combined.items():
if entity not in data:
data[entity] = pd.DataFrame()
for var in variables:
if var not in id_columns and var not in weight_columns:
data[entity][var] = microsim.calculate(
Expand Down
49 changes: 49 additions & 0 deletions tests/test_household_impact.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,52 @@ def test_input_is_json_serializable(self):
json_dict = household.model_dump()
assert isinstance(json_dict, dict)
assert "people" in json_dict


class TestExtraVariables:
    """Callers can request variables beyond the bundled entity_variables."""

    def test__given_extra_tax_unit_variable__then_output_includes_it(self):
        """A requested tax-unit variable appears in the tax-unit output."""
        # adjusted_gross_income is not in us_latest.entity_variables["tax_unit"]
        # by default; this is the class of variables benchmark suites need.
        # Asserting the precondition makes the test fail loudly if the
        # default set ever grows to include it.
        assert "adjusted_gross_income" not in us_latest.entity_variables["tax_unit"]

        household = USHouseholdInput(
            people=[
                {
                    "age": 35,
                    "employment_income": 60000,
                    "is_tax_unit_head": True,
                }
            ],
            year=2026,
        )
        result = calculate_us_household_impact(
            household,
            extra_variables={"tax_unit": ["adjusted_gross_income"]},
        )

        assert "adjusted_gross_income" in result.tax_unit[0]
        assert result.tax_unit[0]["adjusted_gross_income"] > 0

    def test__given_extra_household_variable__then_output_includes_it(self):
        """A requested household variable appears in the household output."""
        household = USHouseholdInput(
            people=[{"age": 35, "employment_income": 60000}],
            year=2026,
        )
        result = calculate_us_household_impact(
            household,
            extra_variables={"household": ["household_market_income"]},
        )

        assert "household_market_income" in result.household

    def test__given_no_extra__then_output_matches_default(self):
        """An empty ``extra_variables`` mapping leaves the output keys unchanged."""
        household = USHouseholdInput(
            people=[{"age": 35, "employment_income": 60000}],
            year=2026,
        )
        default = calculate_us_household_impact(household)
        extra = calculate_us_household_impact(household, extra_variables={})

        assert set(default.tax_unit[0].keys()) == set(extra.tax_unit[0].keys())
        # The household output is built from the same merged variable
        # list, so its key set must be unaffected as well.
        assert set(default.household.keys()) == set(extra.household.keys())
Loading