diff --git a/changelog.d/extra-variables-household-impact.added.md b/changelog.d/extra-variables-household-impact.added.md new file mode 100644 index 00000000..a0ccc408 --- /dev/null +++ b/changelog.d/extra-variables-household-impact.added.md @@ -0,0 +1,9 @@ +`calculate_household_impact` (US and UK) and `Simulation` now accept an +`extra_variables` mapping (`{entity_name: [variable_name, ...]}`) so +callers can request variables beyond the bundled `entity_variables` +default set without monkey-patching. This unblocks benchmark suites +(e.g. `policybench`) that need variables such as +`adjusted_gross_income`, `state_agi`, `free_school_meals`, or +`is_medicaid_eligible` that the default list does not include. The +returned `USHouseholdOutput` / `UKHouseholdOutput` dicts gain the +requested keys; existing keys are unchanged. diff --git a/src/policyengine/core/simulation.py b/src/policyengine/core/simulation.py index 6456e5bc..92bf0975 100644 --- a/src/policyengine/core/simulation.py +++ b/src/policyengine/core/simulation.py @@ -44,6 +44,15 @@ class Simulation(BaseModel): tax_benefit_model_version: TaxBenefitModelVersion = None + extra_variables: dict[str, list[str]] = Field( + default_factory=dict, + description=( + "Additional variables to calculate beyond the model version's " + "default entity_variables, keyed by entity name. Use when a " + "caller needs variables that are not in the bundled default set." + ), + ) + @model_validator(mode="after") def _auto_construct_strategy(self) -> "Simulation": """Auto-construct a RowFilterStrategy from legacy filter fields. diff --git a/src/policyengine/tax_benefit_models/uk/model.py b/src/policyengine/tax_benefit_models/uk/model.py index 1d6711d0..ef7ebd2c 100644 --- a/src/policyengine/tax_benefit_models/uk/model.py +++ b/src/policyengine/tax_benefit_models/uk/model.py @@ -376,7 +376,18 @@ def run(self, simulation: "Simulation") -> "Simulation": "household": pd.DataFrame(), } - for entity, variables in self.entity_variables.items(): + combined: dict[str, list[str]] = { + entity: list(variables) + for entity, variables in self.entity_variables.items() + } + for entity, extras in (simulation.extra_variables or {}).items(): + combined.setdefault(entity, []) + for var in extras: + if var not in combined[entity]: + combined[entity].append(var) + for entity, variables in combined.items(): + if entity not in data: + data[entity] = pd.DataFrame() for var in variables: data[entity][var] = microsim.calculate( var, period=simulation.dataset.year, map_to=entity diff --git a/src/policyengine/tax_benefit_models/us/analysis.py b/src/policyengine/tax_benefit_models/us/analysis.py index 122ae2af..529e6e7b 100644 --- a/src/policyengine/tax_benefit_models/us/analysis.py +++ b/src/policyengine/tax_benefit_models/us/analysis.py @@ -55,8 +55,15 @@ class USHouseholdInput(BaseModel): def calculate_household_impact( household_input: USHouseholdInput, policy: Optional[Policy] = None, + extra_variables: Optional[dict[str, list[str]]] = None, ) -> USHouseholdOutput: - """Calculate tax and benefit impacts for a single US household.""" + """Calculate tax and benefit impacts for a single US household. + + ``extra_variables`` is a mapping from entity name (``person``, + ``tax_unit``, ``household``, etc.) to additional variable names to + compute beyond ``us_latest.entity_variables``. Useful for benchmark + suites that need variables outside the bundled default set. + """ n_people = len(household_input.people) # Build person data with defaults @@ -148,6 +155,7 @@ def calculate_household_impact( dataset=dataset, tax_benefit_model_version=us_latest, policy=policy, + extra_variables=extra_variables or {}, ) simulation.run() @@ -161,13 +169,22 @@ def safe_convert(value): except (ValueError, TypeError): return str(value) + extras = extra_variables or {} + + def variables_for(entity_name: str) -> list[str]: + default = list(us_latest.entity_variables.get(entity_name, [])) + for var in extras.get(entity_name, []): + if var not in default: + default.append(var) + return default + def extract_entity_outputs( entity_name: str, entity_data, n_rows: int ) -> list[dict[str, Any]]: outputs = [] for i in range(n_rows): row_dict = {} - for var in us_latest.entity_variables[entity_name]: + for var in variables_for(entity_name): row_dict[var] = safe_convert(entity_data[var].iloc[i]) outputs.append(row_dict) return outputs @@ -182,7 +199,7 @@ def extract_entity_outputs( tax_unit=extract_entity_outputs("tax_unit", output_data.tax_unit, 1), household={ var: safe_convert(output_data.household[var].iloc[0]) - for var in us_latest.entity_variables["household"] + for var in variables_for("household") }, ) diff --git a/src/policyengine/tax_benefit_models/us/model.py b/src/policyengine/tax_benefit_models/us/model.py index f5aca625..239b0c7f 100644 --- a/src/policyengine/tax_benefit_models/us/model.py +++ b/src/policyengine/tax_benefit_models/us/model.py @@ -403,8 +403,21 @@ def run(self, simulation: "Simulation") -> "Simulation": if target_col in id_columns: data["person"][target_col] = person_input_df[col].values - # Then calculate non-ID, non-weight variables from simulation - for entity, variables in self.entity_variables.items(): + # Then calculate non-ID, non-weight variables from simulation, + # merging the model version's default entity_variables with any + # extra variables requested on the Simulation. + combined: dict[str, list[str]] = { + entity: list(variables) + for entity, variables in self.entity_variables.items() + } + for entity, extras in (simulation.extra_variables or {}).items(): + combined.setdefault(entity, []) + for var in extras: + if var not in combined[entity]: + combined[entity].append(var) + for entity, variables in combined.items(): + if entity not in data: + data[entity] = pd.DataFrame() for var in variables: if var not in id_columns and var not in weight_columns: data[entity][var] = microsim.calculate( diff --git a/tests/test_household_impact.py b/tests/test_household_impact.py index 54f6ac19..e7e4bc32 100644 --- a/tests/test_household_impact.py +++ b/tests/test_household_impact.py @@ -200,3 +200,52 @@ def test_input_is_json_serializable(self): json_dict = household.model_dump() assert isinstance(json_dict, dict) assert "people" in json_dict + + +class TestExtraVariables: + """Callers can request variables beyond the bundled entity_variables.""" + + def test__given_extra_tax_unit_variable__then_output_includes_it(self): + # adjusted_gross_income is not in us_latest.entity_variables["tax_unit"] + # by default; this is the class of variables benchmark suites need. + assert "adjusted_gross_income" not in us_latest.entity_variables["tax_unit"] + + household = USHouseholdInput( + people=[ + { + "age": 35, + "employment_income": 60000, + "is_tax_unit_head": True, + } + ], + year=2026, + ) + result = calculate_us_household_impact( + household, + extra_variables={"tax_unit": ["adjusted_gross_income"]}, + ) + + assert "adjusted_gross_income" in result.tax_unit[0] + assert result.tax_unit[0]["adjusted_gross_income"] > 0 + + def test__given_extra_household_variable__then_output_includes_it(self): + household = USHouseholdInput( + people=[{"age": 35, "employment_income": 60000}], + year=2026, + ) + result = calculate_us_household_impact( + household, + extra_variables={"household": ["household_market_income"]}, + ) + + assert "household_market_income" in result.household + + def test__given_no_extra__then_output_matches_default(self): + household = USHouseholdInput( + people=[{"age": 35, "employment_income": 60000}], + year=2026, + ) + default = calculate_us_household_impact(household) + extra = calculate_us_household_impact(household, extra_variables={}) + + assert set(default.tax_unit[0].keys()) == set(extra.tax_unit[0].keys())