From f8dca1371c05ff60e39a3194b4c6b5a0a8251f62 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Sat, 18 Apr 2026 16:14:58 -0400 Subject: [PATCH] Let callers request extra variables from calculate_household_impact MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds an `extra_variables` mapping (`{entity_name: [variable_name,...]}`) to: - policyengine.core.Simulation - policyengine.tax_benefit_models.us.calculate_household_impact - PolicyEngineUSLatest.run / PolicyEngineUKLatest.run Previously both simulations computed only the hardcoded `self.entity_variables` list (a small default curated for UI/API views) and returned outputs missing anything else — including adjusted_gross_income, state_agi, free_school_meals, is_medicaid_eligible, income_tax_refundable_credits, state_refundable_credits, state_income_tax_before_refundable_credits, and every other variable a benchmark suite typically needs. The fix threads a simulation-scoped override through the run path. Extra variables are merged with the defaults (dedup; order preserved for the defaults) so existing callers see no change, and the returned USHouseholdOutput gains the requested keys. On the UK side only PolicyEngineUKLatest.run is changed here: uk/analysis.py is not part of this patch, so the UK calculate_household_impact wrapper and UKHouseholdOutput are unchanged, and UK callers set `extra_variables` on the Simulation directly. Unblocks the policybench migration to policyengine.py: callers can now do calculate_household_impact( input, extra_variables={"tax_unit": ["adjusted_gross_income", ...]}, ) without monkey-patching us_latest.entity_variables. Tests: three new tests in test_household_impact.py cover the US tax_unit and household paths plus the no-op behaviour when extra_variables is empty/omitted. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- .../extra-variables-household-impact.added.md | 9 ++++ src/policyengine/core/simulation.py | 9 ++++ .../tax_benefit_models/uk/model.py | 13 ++++- .../tax_benefit_models/us/analysis.py | 23 +++++++-- .../tax_benefit_models/us/model.py | 17 ++++++- tests/test_household_impact.py | 49 +++++++++++++++++++ 6 files changed, 114 insertions(+), 6 deletions(-) create mode 100644 changelog.d/extra-variables-household-impact.added.md diff --git a/changelog.d/extra-variables-household-impact.added.md b/changelog.d/extra-variables-household-impact.added.md new file mode 100644 index 00000000..a0ccc408 --- /dev/null +++ b/changelog.d/extra-variables-household-impact.added.md @@ -0,0 +1,9 @@ +`calculate_household_impact` (US and UK) and `Simulation` now accept an +`extra_variables` mapping (`{entity_name: [variable_name, ...]}`) so +callers can request variables beyond the bundled `entity_variables` +default set without monkey-patching. This unblocks benchmark suites +(e.g. `policybench`) that need variables such as +`adjusted_gross_income`, `state_agi`, `free_school_meals`, or +`is_medicaid_eligible` that the default list does not include. The +returned `USHouseholdOutput` / `UKHouseholdOutput` dicts gain the +requested keys; existing keys are unchanged. diff --git a/src/policyengine/core/simulation.py b/src/policyengine/core/simulation.py index 6456e5bc..92bf0975 100644 --- a/src/policyengine/core/simulation.py +++ b/src/policyengine/core/simulation.py @@ -44,6 +44,15 @@ class Simulation(BaseModel): tax_benefit_model_version: TaxBenefitModelVersion = None + extra_variables: dict[str, list[str]] = Field( + default_factory=dict, + description=( + "Additional variables to calculate beyond the model version's " + "default entity_variables, keyed by entity name. Use when a " + "caller needs variables that are not in the bundled default set." 
+ ), + ) + @model_validator(mode="after") def _auto_construct_strategy(self) -> "Simulation": """Auto-construct a RowFilterStrategy from legacy filter fields. diff --git a/src/policyengine/tax_benefit_models/uk/model.py b/src/policyengine/tax_benefit_models/uk/model.py index 1d6711d0..ef7ebd2c 100644 --- a/src/policyengine/tax_benefit_models/uk/model.py +++ b/src/policyengine/tax_benefit_models/uk/model.py @@ -376,7 +376,18 @@ def run(self, simulation: "Simulation") -> "Simulation": "household": pd.DataFrame(), } - for entity, variables in self.entity_variables.items(): + combined: dict[str, list[str]] = { + entity: list(variables) + for entity, variables in self.entity_variables.items() + } + for entity, extras in (simulation.extra_variables or {}).items(): + combined.setdefault(entity, []) + for var in extras: + if var not in combined[entity]: + combined[entity].append(var) + for entity, variables in combined.items(): + if entity not in data: + data[entity] = pd.DataFrame() for var in variables: data[entity][var] = microsim.calculate( var, period=simulation.dataset.year, map_to=entity diff --git a/src/policyengine/tax_benefit_models/us/analysis.py b/src/policyengine/tax_benefit_models/us/analysis.py index 122ae2af..529e6e7b 100644 --- a/src/policyengine/tax_benefit_models/us/analysis.py +++ b/src/policyengine/tax_benefit_models/us/analysis.py @@ -55,8 +55,15 @@ class USHouseholdInput(BaseModel): def calculate_household_impact( household_input: USHouseholdInput, policy: Optional[Policy] = None, + extra_variables: Optional[dict[str, list[str]]] = None, ) -> USHouseholdOutput: - """Calculate tax and benefit impacts for a single US household.""" + """Calculate tax and benefit impacts for a single US household. + + ``extra_variables`` is a mapping from entity name (``person``, + ``tax_unit``, ``household``, etc.) to additional variable names to + compute beyond ``us_latest.entity_variables``. 
Useful for benchmark + suites that need variables outside the bundled default set. + """ n_people = len(household_input.people) # Build person data with defaults @@ -148,6 +155,7 @@ def calculate_household_impact( dataset=dataset, tax_benefit_model_version=us_latest, policy=policy, + extra_variables=extra_variables or {}, ) simulation.run() @@ -161,13 +169,22 @@ def safe_convert(value): except (ValueError, TypeError): return str(value) + extras = extra_variables or {} + + def variables_for(entity_name: str) -> list[str]: + default = list(us_latest.entity_variables.get(entity_name, [])) + for var in extras.get(entity_name, []): + if var not in default: + default.append(var) + return default + def extract_entity_outputs( entity_name: str, entity_data, n_rows: int ) -> list[dict[str, Any]]: outputs = [] for i in range(n_rows): row_dict = {} - for var in us_latest.entity_variables[entity_name]: + for var in variables_for(entity_name): row_dict[var] = safe_convert(entity_data[var].iloc[i]) outputs.append(row_dict) return outputs @@ -182,7 +199,7 @@ def extract_entity_outputs( tax_unit=extract_entity_outputs("tax_unit", output_data.tax_unit, 1), household={ var: safe_convert(output_data.household[var].iloc[0]) - for var in us_latest.entity_variables["household"] + for var in variables_for("household") }, ) diff --git a/src/policyengine/tax_benefit_models/us/model.py b/src/policyengine/tax_benefit_models/us/model.py index f5aca625..239b0c7f 100644 --- a/src/policyengine/tax_benefit_models/us/model.py +++ b/src/policyengine/tax_benefit_models/us/model.py @@ -403,8 +403,21 @@ def run(self, simulation: "Simulation") -> "Simulation": if target_col in id_columns: data["person"][target_col] = person_input_df[col].values - # Then calculate non-ID, non-weight variables from simulation - for entity, variables in self.entity_variables.items(): + # Then calculate non-ID, non-weight variables from simulation, + # merging the model version's default entity_variables with any + # extra 
variables requested on the Simulation. + combined: dict[str, list[str]] = { + entity: list(variables) + for entity, variables in self.entity_variables.items() + } + for entity, extras in (simulation.extra_variables or {}).items(): + combined.setdefault(entity, []) + for var in extras: + if var not in combined[entity]: + combined[entity].append(var) + for entity, variables in combined.items(): + if entity not in data: + data[entity] = pd.DataFrame() for var in variables: if var not in id_columns and var not in weight_columns: data[entity][var] = microsim.calculate( diff --git a/tests/test_household_impact.py b/tests/test_household_impact.py index 54f6ac19..e7e4bc32 100644 --- a/tests/test_household_impact.py +++ b/tests/test_household_impact.py @@ -200,3 +200,52 @@ def test_input_is_json_serializable(self): json_dict = household.model_dump() assert isinstance(json_dict, dict) assert "people" in json_dict + + +class TestExtraVariables: + """Callers can request variables beyond the bundled entity_variables.""" + + def test__given_extra_tax_unit_variable__then_output_includes_it(self): + # adjusted_gross_income is not in us_latest.entity_variables["tax_unit"] + # by default; this is the class of variables benchmark suites need. 
+ assert "adjusted_gross_income" not in us_latest.entity_variables["tax_unit"] + + household = USHouseholdInput( + people=[ + { + "age": 35, + "employment_income": 60000, + "is_tax_unit_head": True, + } + ], + year=2026, + ) + result = calculate_us_household_impact( + household, + extra_variables={"tax_unit": ["adjusted_gross_income"]}, + ) + + assert "adjusted_gross_income" in result.tax_unit[0] + assert result.tax_unit[0]["adjusted_gross_income"] > 0 + + def test__given_extra_household_variable__then_output_includes_it(self): + household = USHouseholdInput( + people=[{"age": 35, "employment_income": 60000}], + year=2026, + ) + result = calculate_us_household_impact( + household, + extra_variables={"household": ["household_market_income"]}, + ) + + assert "household_market_income" in result.household + + def test__given_no_extra__then_output_matches_default(self): + household = USHouseholdInput( + people=[{"age": 35, "employment_income": 60000}], + year=2026, + ) + default = calculate_us_household_impact(household) + extra = calculate_us_household_impact(household, extra_variables={}) + + assert set(default.tax_unit[0].keys()) == set(extra.tax_unit[0].keys())