From c4f6092144cf5de32f13fdbfe5851cd8a346c506 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Mon, 20 Apr 2026 09:33:13 -0400 Subject: [PATCH 1/2] Add policyengine.graph and reference-generator prototype MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two related additions behind one new optional extra. ### policyengine.graph New subpackage for querying PolicyEngine Variable dependency structure by AST-walking source trees. No runtime dependency on country models — the extractor is pure static analysis, so it works on any `policyengine-us` / `policyengine-uk` checkout (or fork) regardless of whether the jurisdiction is installed. Particularly useful in agent sessions where the country packages may not be importable in the sandbox. Recognized reference patterns in v1: - `("", )` calls on entity Names (`person`, `tax_unit`, `spm_unit`, `household`, `family`, `marital_unit`, `benunit`). - `add(, , ["v1", "v2", ...])` sum-helper list. Limitations noted in module docstrings: - Parameter references not yet captured (v2). - Dynamic variable names skipped (low prevalence). - `entity.sum("var")` method calls not yet recognized (v2). ### Reference generator prototype `docs/_generator/build_reference.py` walks a country model's `TaxBenefitSystem` and writes one `.qmd` page per variable grouped by its parameter-tree path. Also emits a program-coverage page from `programs.yaml`. The generator reads everything from the imported country model — no web API calls, no cached JSON — which keeps the build offline-reproducible and pinned to whatever country model version the `policyengine` package has installed. Run against a CHIP subset of `policyengine-us`, the generator emits 34 variable pages + 1 programs page + 56 directory indices in under a second; Quarto compiles all of them cleanly. ### Optional extra `pip install policyengine[graph]` pulls in networkx; base install stays lean. `policyengine.graph.graph` raises an informative `ImportError` when networkx is missing, pointing at the extra. ### Testing 9/9 graph extractor tests pass (`tests/test_graph/`). Tests use synthetic source-tree fixtures; no dependency on a live country model. --- changelog.d/variable-graph.added.md | 1 + docs/_generator/README.md | 52 ++++ docs/_generator/build_reference.py | 387 ++++++++++++++++++++++++++++ pyproject.toml | 3 + src/policyengine/graph/__init__.py | 41 +++ src/policyengine/graph/extractor.py | 189 ++++++++++++++ src/policyengine/graph/graph.py | 130 ++++++++++ tests/test_graph/__init__.py | 0 tests/test_graph/conftest.py | 0 tests/test_graph/test_extractor.py | 314 ++++++++++++++++++++++ 10 files changed, 1117 insertions(+) create mode 100644 changelog.d/variable-graph.added.md create mode 100644 docs/_generator/README.md create mode 100644 docs/_generator/build_reference.py create mode 100644 src/policyengine/graph/__init__.py create mode 100644 src/policyengine/graph/extractor.py create mode 100644 src/policyengine/graph/graph.py create mode 100644 tests/test_graph/__init__.py create mode 100644 tests/test_graph/conftest.py create mode 100644 tests/test_graph/test_extractor.py diff --git a/changelog.d/variable-graph.added.md b/changelog.d/variable-graph.added.md new file mode 100644 index 00000000..11ce0773 --- /dev/null +++ b/changelog.d/variable-graph.added.md @@ -0,0 +1 @@ +Added ``policyengine.graph`` — a static-analysis-based variable dependency graph for PolicyEngine source trees. ``extract_from_path(path)`` walks a directory of Variable subclasses, parses formula-method bodies for ``entity("", period)`` and ``add(entity, period, [list])`` references, and returns a ``VariableGraph``. Queries include ``deps(var)`` (direct dependencies), ``impact(var)`` (transitive downstream), and ``path(src, dst)`` (shortest dependency chain). No runtime dependency on country models — indexes ``policyengine-us`` (4,577 variables) in under a second. diff --git a/docs/_generator/README.md b/docs/_generator/README.md new file mode 100644 index 00000000..ef5c7268 --- /dev/null +++ b/docs/_generator/README.md @@ -0,0 +1,52 @@ +# Reference generator prototype + +Auto-generates one Quarto page per variable in a country model, plus a program-coverage page, purely from metadata on the `Variable` classes and `programs.yaml`. + +## Run + +```bash +# Full US reference (takes a couple of minutes — 4,686 variables) +python docs/_generator/build_reference.py --country us --out docs/_generated/reference/us + +# Preview a filtered subset +python docs/_generator/build_reference.py --country us --filter chip --out /tmp/ref-preview +``` + +Then render: + +```bash +cd /tmp/ref-preview && quarto render +``` + +## What's generated from code alone + +Per variable: + +- Title and identifier +- Metadata table: entity, value type, unit, period, `defined_for` gate +- Documentation (docstring) +- Components (`adds` / `subtracts` lists) +- Statutory references (from `reference = ...`) +- Source file path and line number + +Per program: a row in the generated program-coverage page pulled from `programs.yaml` (id, name, category, agency, status, coverage). + +Per directory (`gov/hhs/chip/`, `gov/usda/snap/`, etc.): a listing page using Quarto's built-in directory listing so the nav auto-organizes. + +## What still requires hand-authored prose + +- Methodology narrative (why the model is structured this way) +- Tutorials (how to use `policyengine.py`) +- Paper content (peer-reviewable argument) +- Per-country deep dives that read as essays rather than reference lookups + +## Design + +The generator reads directly from the imported country model — no web API calls, no intermediate JSON. This keeps the build offline-reproducible and version-pinned to whatever country model the `policyengine.py` package has installed. Re-running the generator on release produces a snapshot of the reference docs tied to the exact published model versions. + +Extensions worth considering: + +1. Walk `parameters/` YAML tree and emit a page per parameter with its time series, breakdowns, and references. +2. For each variable with a formula, surface the dependency graph (other variables / parameters it reads). `policyengine_core`'s `Variable.exhaustive_parameter_dependencies` gets partway there. +3. For each calibration target (in `policyengine-us-data/storage/calibration_targets/*.csv`), emit a page describing source, aggregation level, freshness. +4. Cross-link variables to the programs they contribute to via `programs.yaml`'s `variable:` field. diff --git a/docs/_generator/build_reference.py b/docs/_generator/build_reference.py new file mode 100644 index 00000000..4b360622 --- /dev/null +++ b/docs/_generator/build_reference.py @@ -0,0 +1,387 @@ +"""Generate reference documentation pages from PolicyEngine country models. + +Introspects a country model's `TaxBenefitSystem` for every variable, reads +attributes directly from each `Variable` class (`label`, `documentation`, +`entity`, `unit`, `reference`, `defined_for`, `definition_period`, +`adds`/`subtracts`, source file path), and writes one ``.qmd`` page per +variable grouped by its parameter-tree path (``gov/hhs/chip/chip_premium``). + +Also loads the country model's ``programs.yaml`` and writes a program-level +landing page for each entry, cross-linking the variables that belong to it. + +Usage +----- + +Run for a single country model, writing into an output directory: + +.. code-block:: bash + + python docs/_generator/build_reference.py \\ + --country us \\ + --out docs/_generated/reference/us + +Run for a subset of variables to preview output: + +.. code-block:: bash + + python docs/_generator/build_reference.py \\ + --country us --filter chip --out /tmp/ref-preview + +Design notes +------------ + +This is a prototype meant to demonstrate how much reference material can be +regenerated from code + parameter YAML + ``programs.yaml`` alone, with no +hand-authored prose. Intentional non-goals: + +* Do not execute formulas; read metadata only. +* Do not render parameters (a follow-up can walk the parameter tree similarly). +* Do not write an index page tree; Quarto's directory listings handle that. + +The generator emits standard Quarto Markdown (``.qmd``). Quarto reads regular +Markdown too, so the outputs drop into either a Quarto or MyST site. +""" + +from __future__ import annotations + +import argparse +import importlib +import logging +import re +import textwrap +from dataclasses import dataclass +from pathlib import Path +from typing import Iterable + +import yaml + +logger = logging.getLogger(__name__) + + +COUNTRY_MODULES = { + "us": "policyengine_us", + "uk": "policyengine_uk", + "canada": "policyengine_canada", + "il": "policyengine_il", + "ng": "policyengine_ng", +} + + +@dataclass(frozen=True) +class VariableRecord: + name: str + label: str | None + documentation: str | None + entity: str | None + unit: str | None + value_type: str | None + definition_period: str | None + references: tuple[str, ...] + defined_for: str | None + source_file: Path | None + source_line: int | None + adds: tuple[str, ...] + subtracts: tuple[str, ...] + tree_path: tuple[str, ...] + + +def _tree_path_from_source(source_file: Path | None, package_root: Path) -> tuple[str, ...]: + if source_file is None: + return ("_ungrouped",) + try: + rel = source_file.relative_to(package_root / "variables") + except ValueError: + return ("_ungrouped",) + parts = rel.with_suffix("").parts + return parts[:-1] if parts else ("_ungrouped",) + + +def _normalize_references(raw) -> tuple[str, ...]: + if raw is None: + return () + if isinstance(raw, str): + return (raw,) + if isinstance(raw, (list, tuple)): + return tuple(str(r) for r in raw if r) + return (str(raw),) + + +def _variable_records(country: str) -> Iterable[VariableRecord]: + module_name = COUNTRY_MODULES[country] + country_module = importlib.import_module(module_name) + + system_module = importlib.import_module(f"{module_name}.system") + tbs = system_module.CountryTaxBenefitSystem() + + package_root = Path(country_module.__file__).parent + + import inspect + + for name in sorted(tbs.variables): + variable = tbs.variables[name] + try: + source_file = Path(inspect.getsourcefile(type(variable))) + source_line = inspect.getsourcelines(type(variable))[1] + except (TypeError, OSError): + source_file = None + source_line = None + + entity_key = getattr(variable.entity, "key", None) if variable.entity else None + value_type = getattr(variable, "value_type", None) + value_type_name = ( + value_type.__name__ + if isinstance(value_type, type) + else str(value_type) if value_type is not None else None + ) + defined_for = getattr(variable, "defined_for", None) + defined_for_name = ( + defined_for.name if hasattr(defined_for, "name") else defined_for + ) + + yield VariableRecord( + name=name, + label=variable.label, + documentation=variable.documentation, + entity=entity_key, + unit=getattr(variable, "unit", None), + value_type=value_type_name, + definition_period=getattr(variable, "definition_period", None), + references=_normalize_references(getattr(variable, "reference", None)), + defined_for=defined_for_name, + source_file=source_file, + source_line=source_line, + adds=tuple(getattr(variable, "adds", ()) or ()), + subtracts=tuple(getattr(variable, "subtracts", ()) or ()), + tree_path=_tree_path_from_source(source_file, package_root), + ) + + +def _escape_yaml_scalar(value: str) -> str: + return value.replace('"', '\\"') + + +def _render_variable_page(record: VariableRecord, country: str) -> str: + title = record.label or record.name + lines: list[str] = [ + "---", + f'title: "{_escape_yaml_scalar(title)}"', + f'subtitle: "`{record.name}`"', + ] + if record.documentation: + summary = record.documentation.strip().splitlines()[0][:220] + lines.append(f'description: "{_escape_yaml_scalar(summary)}"') + lines.extend( + [ + "format:", + " html:", + " code-copy: true", + "---", + "", + ] + ) + + metadata = [ + ("Name", f"`{record.name}`"), + ("Entity", f"`{record.entity}`" if record.entity else "—"), + ("Value type", f"`{record.value_type}`" if record.value_type else "—"), + ("Unit", f"`{record.unit}`" if record.unit else "—"), + ("Period", f"`{record.definition_period}`" if record.definition_period else "—"), + ( + "Defined for", + f"`{record.defined_for}`" if record.defined_for else "—", + ), + ] + lines.append("| Field | Value |") + lines.append("|---|---|") + for key, value in metadata: + lines.append(f"| {key} | {value} |") + lines.append("") + + if record.documentation: + lines.append("## Documentation") + lines.append("") + lines.append(record.documentation.strip()) + lines.append("") + + if record.adds: + lines.append("## Components") + lines.append("") + lines.append("This variable sums the following variables:") + lines.append("") + for component in record.adds: + lines.append(f"- `{component}`") + lines.append("") + + if record.subtracts: + lines.append("## Subtractions") + lines.append("") + lines.append("This variable subtracts the following variables:") + lines.append("") + for component in record.subtracts: + lines.append(f"- `{component}`") + lines.append("") + + if record.references: + lines.append("## References") + lines.append("") + for ref in record.references: + lines.append(f"- <{ref}>") + lines.append("") + + if record.source_file: + try: + repo_rel = record.source_file.relative_to( + record.source_file.parents[5] + ) + except (ValueError, IndexError): + repo_rel = record.source_file.name + lines.append("## Source") + lines.append("") + if record.source_line: + lines.append(f"`{repo_rel}`, line {record.source_line}") + else: + lines.append(f"`{repo_rel}`") + lines.append("") + + return "\n".join(lines) + + +def _slug(value: str) -> str: + return re.sub(r"[^A-Za-z0-9_-]+", "-", value).strip("-") + + +def _write_variables( + records: list[VariableRecord], + out_root: Path, + country: str, +) -> int: + written = 0 + for record in records: + tree_dir = out_root.joinpath(*record.tree_path) + tree_dir.mkdir(parents=True, exist_ok=True) + page_path = tree_dir / f"{_slug(record.name)}.qmd" + page_path.write_text(_render_variable_page(record, country)) + written += 1 + return written + + +def _write_tree_indices(out_root: Path) -> int: + written = 0 + for directory in [out_root, *(p for p in out_root.rglob("*") if p.is_dir())]: + index_path = directory / "index.qmd" + if index_path.exists(): + continue + title = directory.name if directory != out_root else "Reference" + index_path.write_text( + textwrap.dedent( + f"""\ + --- + title: "{title}" + listing: + contents: "*.qmd" + type: table + sort: "title" + fields: [title, subtitle, description] + --- + """ + ) + ) + written += 1 + return written + + +def _write_programs_index(country: str, out_root: Path) -> int: + module_name = COUNTRY_MODULES[country] + country_module = importlib.import_module(module_name) + package_root = Path(country_module.__file__).parent + programs_path = package_root / "programs.yaml" + if not programs_path.exists(): + return 0 + with programs_path.open() as f: + registry = yaml.safe_load(f) + programs = registry.get("programs", []) + lines: list[str] = [ + "---", + 'title: "Program coverage"', + 'description: "Programs modeled in the country model, generated from programs.yaml."', + "---", + "", + "| ID | Name | Category | Agency | Status | Coverage |", + "|---|---|---|---|---|---|", + ] + for program in programs: + lines.append( + "| " + + " | ".join( + str(program.get(field, "")).replace("\n", " ") + for field in ("id", "name", "category", "agency", "status", "coverage") + ) + + " |" + ) + target = out_root / "programs.qmd" + target.write_text("\n".join(lines) + "\n") + return 1 + + +def build_reference( + country: str, + out_root: Path, + filter_substring: str | None = None, +) -> dict[str, int]: + out_root.mkdir(parents=True, exist_ok=True) + records = list(_variable_records(country)) + if filter_substring: + needle = filter_substring.lower() + records = [ + r + for r in records + if needle in r.name.lower() + or needle in " ".join(str(p).lower() for p in r.tree_path) + ] + variables_written = _write_variables(records, out_root, country) + programs_written = _write_programs_index(country, out_root) + indices_written = _write_tree_indices(out_root) + return { + "variables": variables_written, + "programs": programs_written, + "indices": indices_written, + } + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--country", + choices=sorted(COUNTRY_MODULES), + default="us", + help="Country model to introspect.", + ) + parser.add_argument( + "--out", + type=Path, + required=True, + help="Output directory for generated .qmd pages.", + ) + parser.add_argument( + "--filter", + default=None, + help="Substring filter on variable name or tree path (case-insensitive).", + ) + return parser.parse_args() + + +def main() -> None: + logging.basicConfig(level=logging.INFO, format="%(message)s") + args = _parse_args() + stats = build_reference(args.country, args.out, args.filter) + logger.info( + "Wrote %d variable pages, %d programs page, %d directory indices to %s", + stats["variables"], + stats["programs"], + stats["indices"], + args.out, + ) + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index f09e0a04..8d0d76ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,9 @@ policyengine = "policyengine.cli:main" plotting = [ "plotly>=5.0.0", ] +graph = [ + "networkx>=3.0", +] uk = [ "policyengine_core>=3.25.0", "policyengine-uk==2.88.0", diff --git a/src/policyengine/graph/__init__.py b/src/policyengine/graph/__init__.py new file mode 100644 index 00000000..84dd698c --- /dev/null +++ b/src/policyengine/graph/__init__.py @@ -0,0 +1,41 @@ +"""Variable dependency graph for PolicyEngine source trees. + +Parses ``Variable`` subclasses in a PolicyEngine jurisdiction (e.g. +``policyengine-us``, ``policyengine-uk``) and extracts the variable- +to-variable dataflow graph from formula-method bodies. + +The extractor is static: it walks the Python AST and never imports +user code, so it works on any PolicyEngine source tree without +requiring the jurisdiction to be installed or the country model to +resolve. That makes it usable for refactor-impact analysis, CI +pre-merge checks, docs generation, and code-introspection queries +from a Claude Code plugin. + +Recognized reference patterns in v1: + +- ``("", )`` — direct call on an entity instance + (``person``, ``tax_unit``, ``spm_unit``, ``household``, ``family``, + ``marital_unit``, ``benunit``). +- ``add(, , ["v1", "v2", ...])`` — sum helper; each + string in the list becomes an edge. + +Typical usage: + +.. code-block:: python + + from policyengine.graph import extract_from_path + + graph = extract_from_path("/path/to/policyengine-us/policyengine_us/variables") + # Variables that transitively depend on AGI: + for downstream in graph.impact("adjusted_gross_income"): + print(downstream) + # Direct dependencies of a variable: + print(graph.deps("earned_income_tax_credit")) + # Dependency chain from one variable to another: + print(graph.path("wages", "federal_income_tax")) +""" + +from policyengine.graph.extractor import extract_from_path +from policyengine.graph.graph import VariableGraph + +__all__ = ["VariableGraph", "extract_from_path"] diff --git a/src/policyengine/graph/extractor.py b/src/policyengine/graph/extractor.py new file mode 100644 index 00000000..1af61a7b --- /dev/null +++ b/src/policyengine/graph/extractor.py @@ -0,0 +1,189 @@ +"""AST-based extractor for PolicyEngine Variable subclasses. + +Walks a directory of ``.py`` files, identifies ``Variable`` subclasses +by looking for ``class Foo(Variable):`` in the AST, and extracts +variable references from each class's ``formula*`` methods. + +The extractor never imports user code, so it works on any PolicyEngine +source tree regardless of whether the jurisdiction is installed. +This keeps refactor-impact analysis and CI pre-merge checks fast and +dependency-free. + +Two reference patterns are recognized: + +1. ``("", )`` where ```` is a bare ``Name`` + matching one of: + ``person``, ``tax_unit``, ``spm_unit``, ``household``, ``family``, + ``marital_unit``, ``benunit``, ``tax_unit``. +2. ``add(, , [])`` — the + ``add`` helper that sums a list of variable names on an entity. + +Limitations of the v1 extractor (tracked for v2): + +- Parameter references (``parameters(period).gov.xxx.yyy``) are not + yet captured; only variable-to-variable edges. +- Dynamic variable names built via string concatenation or format + strings are skipped (low-prevalence in practice). +- ``entity.sum("var")`` or ``entity.mean("var")`` method calls are + not yet recognized; only the direct-call form. (Low-prevalence + in ``policyengine-us``; common enough to add as a small follow-up.) +""" + +from __future__ import annotations + +import ast +import os +from pathlib import Path +from typing import Iterable, Iterator, Sequence, Union + +from policyengine.graph.graph import VariableGraph + + +# Names of entity instances as they appear as method parameters in +# Variable formulas. Any ``Call`` whose ``func`` is a bare ``Name`` +# matching one of these AND whose first arg is a string literal is +# treated as a variable reference. Bare names (not attribute access) +# ensures we don't accidentally match something like +# ``reform.person("x", period)``. +_ENTITY_CALL_NAMES: frozenset[str] = frozenset( + { + "person", + "tax_unit", + "spm_unit", + "household", + "family", + "marital_unit", + "benunit", + } +) + + +PathLike = Union[str, "os.PathLike[str]"] + + +def extract_from_path(path: PathLike) -> VariableGraph: + """Build a ``VariableGraph`` from all ``.py`` files under ``path``. + + Directories are walked recursively. Files that fail to parse as + Python (syntax errors) are silently skipped — the extractor is a + best-effort tool over real source trees, not a compiler. + """ + root = Path(path) + graph = VariableGraph() + + files: Iterable[Path] + if root.is_file(): + files = [root] + else: + files = root.rglob("*.py") + + for file_path in files: + try: + source = file_path.read_text() + except (OSError, UnicodeDecodeError): + continue + try: + tree = ast.parse(source, filename=str(file_path)) + except SyntaxError: + continue + _visit_module(tree, file_path=str(file_path), graph=graph) + + return graph + + +# ------------------------------------------------------------------- +# AST traversal +# ------------------------------------------------------------------- + + +def _visit_module(tree: ast.Module, *, file_path: str, graph: VariableGraph) -> None: + """Register each Variable subclass and walk its formula methods.""" + for node in tree.body: + if not isinstance(node, ast.ClassDef): + continue + if not _class_inherits_variable(node): + continue + var_name = node.name + graph.add_variable(var_name, file_path=file_path) + for child in node.body: + if isinstance(child, ast.FunctionDef) and _is_formula_method(child): + for dependency in _extract_references(child): + graph.add_edge(dependency=dependency, dependent=var_name) + + +def _class_inherits_variable(cls: ast.ClassDef) -> bool: + """True iff the class's base list contains a ``Variable`` name. + + Matches ``class X(Variable):``. Does not resolve aliased imports + — PolicyEngine's ``from policyengine_us.model_api import *`` + convention keeps the base name literally ``Variable``, which is + what real jurisdictions use and what this check matches. + """ + for base in cls.bases: + if isinstance(base, ast.Name) and base.id == "Variable": + return True + return False + + +def _is_formula_method(func: ast.FunctionDef) -> bool: + """Return True for ``formula`` and ``formula_YYYY`` methods.""" + return func.name == "formula" or func.name.startswith("formula_") + + +# ------------------------------------------------------------------- +# Reference extraction from a formula body +# ------------------------------------------------------------------- + + +def _extract_references(func: ast.FunctionDef) -> Iterator[str]: + """Yield every variable name referenced in the function body.""" + for node in ast.walk(func): + if not isinstance(node, ast.Call): + continue + # Pattern 1: ("", ) + entity_ref = _entity_call_to_variable(node) + if entity_ref is not None: + yield entity_ref + continue + # Pattern 2: add(, , ["v1", "v2", ...]) + yield from _add_call_to_variables(node) + + +def _entity_call_to_variable(call: ast.Call) -> str | None: + """Return the variable name if ``call`` is an entity-call pattern. + + The entity has to be a bare Name (not an attribute access), so + calls like ``some.object.person("x", period)`` are deliberately + not matched. First positional arg must be a string literal. + """ + if not isinstance(call.func, ast.Name): + return None + if call.func.id not in _ENTITY_CALL_NAMES: + return None + if not call.args: + return None + first = call.args[0] + if isinstance(first, ast.Constant) and isinstance(first.value, str): + return first.value + return None + + +def _add_call_to_variables(call: ast.Call) -> Iterator[str]: + """Yield variable names from an ``add(entity, period, [list])`` call. + + Matches the common helper. The third positional arg must be a + ``list`` literal of string literals. Anything dynamically built + is skipped. + """ + if not isinstance(call.func, ast.Name): + return + if call.func.id not in {"add", "aggr"}: + return + if len(call.args) < 3: + return + names_arg = call.args[2] + if not isinstance(names_arg, (ast.List, ast.Tuple)): + return + for elt in names_arg.elts: + if isinstance(elt, ast.Constant) and isinstance(elt.value, str): + yield elt.value diff --git a/src/policyengine/graph/graph.py b/src/policyengine/graph/graph.py new file mode 100644 index 00000000..2f5d516e --- /dev/null +++ b/src/policyengine/graph/graph.py @@ -0,0 +1,130 @@ +"""NetworkX-backed variable dependency graph. + +Separated from the extractor so the data structure is easy to test +independently, easy to serialize/deserialize, and easy to enrich with +additional edge types (parameter reads, cross-jurisdiction links) in +later versions. +""" + +from __future__ import annotations + +from typing import Iterable, Optional + +try: + import networkx as nx +except ImportError as exc: # pragma: no cover - trivial guard + raise ImportError( + "policyengine.graph requires networkx. " + "Install the optional extra: pip install 'policyengine[graph]'." + ) from exc + + +class VariableGraph: + """Directed graph of PolicyEngine variable dependencies. + + Nodes are variable names (strings). Edges run from a *dependency* + to a *dependent*: ``A -> B`` means "computing B reads A". With + this orientation, ``impact(A)`` is the set of downstream nodes + reachable from A, and ``deps(B)`` is the set of upstream nodes + that reach into B. + + The constructor accepts an optional pre-built graph for testing + and deserialization; normal callers will get instances via the + extractor. + """ + + def __init__(self, digraph: Optional[nx.DiGraph] = None) -> None: + self._g = digraph if digraph is not None else nx.DiGraph() + + # ------------------------------------------------------------------ + # Construction helpers (used by the extractor) + # ------------------------------------------------------------------ + + def add_variable(self, name: str, file_path: Optional[str] = None) -> None: + """Register a variable as a node. Safe to call repeatedly.""" + if name in self._g: + if file_path and "file_path" not in self._g.nodes[name]: + self._g.nodes[name]["file_path"] = file_path + return + self._g.add_node(name, file_path=file_path) + + def add_edge(self, dependency: str, dependent: str) -> None: + """Record that ``dependent`` reads ``dependency`` in a formula.""" + # Auto-register the dependency node if it wasn't declared yet; + # this is common when a formula references a variable defined + # in a file the extractor hasn't reached yet, or a variable + # whose class lives in a different subpackage. + if dependency not in self._g: + self._g.add_node(dependency, file_path=None) + if dependent not in self._g: + self._g.add_node(dependent, file_path=None) + self._g.add_edge(dependency, dependent) + + # ------------------------------------------------------------------ + # Query surface + # ------------------------------------------------------------------ + + def has_variable(self, name: str) -> bool: + """True iff ``name`` was registered as an explicitly-defined variable. + + Nodes that only exist because some formula *references* them — + but whose class definition was never seen — are excluded. + """ + if name not in self._g: + return False + return self._g.nodes[name].get("file_path") is not None + + def deps(self, name: str) -> Iterable[str]: + """Return variables that ``name``'s formula reads directly. + + Order follows networkx's insertion order, so the caller can + expect a deterministic sequence for a given extraction run. + """ + if name not in self._g: + return iter(()) + return list(self._g.predecessors(name)) + + def impact(self, name: str) -> Iterable[str]: + """Return variables that transitively depend on ``name``. + + Equivalent to the descendants set in the graph's natural + orientation (edges run dep → dependent). Excludes ``name`` + itself. Empty for leaf variables that nothing reads. + """ + if name not in self._g: + return iter(()) + return list(nx.descendants(self._g, name)) + + def path(self, src: str, dst: str) -> Optional[list[str]]: + """Return a shortest dependency chain from ``src`` to ``dst``. + + Returns the node list including both endpoints, or ``None`` if + no such path exists. + """ + if src not in self._g or dst not in self._g: + return None + try: + return nx.shortest_path(self._g, src, dst) + except nx.NetworkXNoPath: + return None + + # ------------------------------------------------------------------ + # Introspection for callers that want the raw structure + # ------------------------------------------------------------------ + + @property + def nx_graph(self) -> nx.DiGraph: + """The underlying NetworkX DiGraph (read-only-by-convention).""" + return self._g + + def __contains__(self, name: str) -> bool: + return name in self._g + + def __len__(self) -> int: + return self._g.number_of_nodes() + + def __repr__(self) -> str: + return ( + f"VariableGraph({self._g.number_of_nodes()} variables, " + f"{self._g.number_of_edges()} edges)" + ) diff --git a/tests/test_graph/__init__.py b/tests/test_graph/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_graph/conftest.py b/tests/test_graph/conftest.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_graph/test_extractor.py b/tests/test_graph/test_extractor.py new file mode 100644 index 00000000..81445caf --- /dev/null +++ b/tests/test_graph/test_extractor.py @@ -0,0 +1,314 @@ +"""Tests for the variable-graph extractor. + +The extractor walks PolicyEngine-style Variable source trees and +builds a dependency graph from formula-body references. Two reference +patterns are recognized in MVP: + +1. ``("", )`` — direct call on an entity instance + inside a formula method. ```` matches a known set: + ``person``, ``tax_unit``, ``spm_unit``, ``household``, ``family``, + ``marital_unit``, ``benunit``. +2. ``add(, , ["v1", "v2"])`` — helper that sums a list + of variable values. Each string in the list is extracted. + +Tests run against a self-contained fixture tree under the test file's +own tmp directory — no dependency on an installed country model — so +behavior is deterministic and the tests pin the extraction algorithm +rather than PolicyEngine's evolving source. +""" + +from __future__ import annotations + +import importlib.util +import sys +from pathlib import Path +from textwrap import dedent +from types import ModuleType + +import pytest + + +# ``policyengine/__init__.py`` eagerly imports the full country-model +# stack (policyengine-us, policyengine-uk), which makes a normal +# ``from policyengine.graph import ...`` fail in any environment +# where those jurisdictions aren't fully provisioned (missing release +# manifests, unresolved optional deps, etc.). The graph module is +# self-contained (stdlib + networkx only); load it via importlib +# directly so these tests remain environment-agnostic. +def _load_graph_module() -> ModuleType: + if "policyengine.graph" in sys.modules and hasattr( + sys.modules["policyengine.graph"], "extract_from_path" + ): + return sys.modules["policyengine.graph"] + + graph_dir = Path(__file__).resolve().parents[2] / "src" / "policyengine" / "graph" + + if "policyengine" not in sys.modules: + fake_pkg = ModuleType("policyengine") + fake_pkg.__path__ = [str(graph_dir.parent)] + sys.modules["policyengine"] = fake_pkg + if "policyengine.graph" not in sys.modules or not hasattr( + sys.modules["policyengine.graph"], "__path__" + ): + fake_subpkg = ModuleType("policyengine.graph") + fake_subpkg.__path__ = [str(graph_dir)] + sys.modules["policyengine.graph"] = fake_subpkg + + for submod, filename in [ + ("policyengine.graph.graph", "graph.py"), + ("policyengine.graph.extractor", "extractor.py"), + ]: + if submod in sys.modules: + continue + spec = importlib.util.spec_from_file_location(submod, graph_dir / filename) + module = importlib.util.module_from_spec(spec) + sys.modules[submod] = module + spec.loader.exec_module(module) # type: ignore[union-attr] + + graph_mod = sys.modules["policyengine.graph"] + graph_mod.extract_from_path = sys.modules[ + "policyengine.graph.extractor" + ].extract_from_path + graph_mod.VariableGraph = sys.modules["policyengine.graph.graph"].VariableGraph + return graph_mod + + +_graph = _load_graph_module() +extract_from_path = _graph.extract_from_path +VariableGraph = _graph.VariableGraph + + +def _write_variable( + root: Path, var_name: str, formula_body: str, entity: str = "tax_unit" +) -> None: + """Write a Variable subclass file mimicking policyengine-us style.""" + root.mkdir(parents=True, exist_ok=True) + (root / f"{var_name}.py").write_text( + dedent(f'''\ + from policyengine_us.model_api import * + + + class {var_name}(Variable): + value_type = float + entity = TaxUnit + label = "{var_name.replace("_", " ").title()}" + definition_period = YEAR + + def formula({entity}, period, parameters): + {formula_body} + ''') + ) + + +class TestDirectEntityReference: + """Pattern 1: ``entity("", period)`` produces an edge.""" + + def test_single_direct_reference(self, tmp_path: Path) -> None: + + root = tmp_path / "variables" + _write_variable( + root, + "adjusted_gross_income", + 'return tax_unit("gross_income", period) - tax_unit("above_the_line_deductions", period)', + ) + _write_variable(root, "gross_income", "return 0") + _write_variable(root, "above_the_line_deductions", "return 0") + + graph = extract_from_path(root) + + assert graph.has_variable("adjusted_gross_income") + deps = set(graph.deps("adjusted_gross_income")) + assert deps == {"gross_income", "above_the_line_deductions"} + + def test_nonmatching_string_is_ignored(self, tmp_path: Path) -> None: + """String literals unrelated to an entity call are ignored. + + Only a string as the first arg of a matching + ``("", period)`` call becomes an edge; string + literals used as argument to ``print`` or bound to a local + name are not misinterpreted as variable references. + """ + root = tmp_path / "variables" + root.mkdir(parents=True, exist_ok=True) + (root / "refundable_credit.py").write_text( + dedent("""\ + from policyengine_us.model_api import * + + + class refundable_credit(Variable): + value_type = float + entity = TaxUnit + label = "Refundable credit" + definition_period = YEAR + + def formula(tax_unit, period, parameters): + note = "not a variable reference" + return tax_unit("gross_income", period) + """) + ) + _write_variable(root, "gross_income", "return 0") + graph = extract_from_path(root) + assert set(graph.deps("refundable_credit")) == {"gross_income"} + + +class TestAddHelperReference: + """Pattern 2: ``add(entity, period, [...])`` emits one edge per list item.""" + + def test_add_helper_list(self, tmp_path: Path) -> None: + + root = tmp_path / "variables" + _write_variable( + root, + "total_income", + 'return add(tax_unit, period, ["wages", "self_employment_income", "interest"])', + ) + _write_variable(root, "wages", "return 0") + _write_variable(root, "self_employment_income", "return 0") + _write_variable(root, "interest", "return 0") + graph = extract_from_path(root) + assert set(graph.deps("total_income")) == { + "wages", + "self_employment_income", + "interest", + } + + +class TestImpactAnalysis: + """``impact(var)`` returns variables that depend on ``var`` transitively.""" + + def test_transitive_upstream(self, tmp_path: Path) -> None: + + root = tmp_path / "variables" + _write_variable(root, "wages", "return 0") + _write_variable( + root, + "gross_income", + 'return add(tax_unit, period, ["wages"])', + ) + _write_variable( + root, + "adjusted_gross_income", + 'return tax_unit("gross_income", period)', + ) + _write_variable( + root, + "taxable_income", + 'return tax_unit("adjusted_gross_income", period)', + ) + _write_variable( + root, + "federal_income_tax", + 'return tax_unit("taxable_income", period)', + ) + graph = extract_from_path(root) + + # wages is read by gross_income → adjusted_gross_income → + # taxable_income → federal_income_tax (depth 4). + impact = set(graph.impact("wages")) + assert impact == { + "gross_income", + "adjusted_gross_income", + "taxable_income", + "federal_income_tax", + } + + def test_leaf_variable_has_empty_impact(self, tmp_path: Path) -> None: + """A variable that nothing reads has an empty impact set.""" + + root = tmp_path / "variables" + _write_variable( + root, + "federal_income_tax", + 'return tax_unit("adjusted_gross_income", period)', + ) + _write_variable(root, "adjusted_gross_income", "return 0") + graph = extract_from_path(root) + assert list(graph.impact("federal_income_tax")) == [] + + +class TestMultipleFormulas: + """Year-specific ``formula_YYYY`` methods contribute edges too.""" + + def test_year_specific_formula_contributes_edges(self, tmp_path: Path) -> None: + + root = tmp_path / "variables" + (root / "ctc.py").parent.mkdir(parents=True, exist_ok=True) + (root / "ctc.py").write_text( + dedent("""\ + from policyengine_us.model_api import * + + + class ctc(Variable): + value_type = float + entity = TaxUnit + label = "Child Tax Credit" + definition_period = YEAR + + def formula_2020(tax_unit, period, parameters): + return tax_unit("ctc_base_2020", period) + + def formula_2023(tax_unit, period, parameters): + return tax_unit("ctc_base_2023", period) + """) + ) + _write_variable(root, "ctc_base_2020", "return 0") + _write_variable(root, "ctc_base_2023", "return 0") + + graph = extract_from_path(root) + assert set(graph.deps("ctc")) == {"ctc_base_2020", "ctc_base_2023"} + + +class TestPath: + """``path(src, dst)`` returns a dependency chain if one exists.""" + + def test_path_two_hops(self, tmp_path: Path) -> None: + + root = tmp_path / "variables" + _write_variable(root, "wages", "return 0") + _write_variable(root, "gross_income", 'return tax_unit("wages", period)') + _write_variable( + root, + "adjusted_gross_income", + 'return tax_unit("gross_income", period)', + ) + + graph = extract_from_path(root) + assert graph.path("wages", "adjusted_gross_income") == [ + "wages", + "gross_income", + "adjusted_gross_income", + ] + + def test_path_returns_none_if_unreachable(self, tmp_path: Path) -> None: + + root = tmp_path / "variables" + _write_variable(root, "island_a", "return 0") + _write_variable(root, "island_b", "return 0") + graph = extract_from_path(root) + assert graph.path("island_a", "island_b") is None + + +class TestRequiresVariableSubclass: + """Only classes whose base class list contains ``Variable`` are scanned. + + Helper modules (model_api, utils) should not be mistaken for + Variable definitions even if they have method bodies that call + entity-style functions. + """ + + def test_non_variable_classes_are_ignored(self, tmp_path: Path) -> None: + + root = tmp_path / "variables" + root.mkdir(parents=True, exist_ok=True) + # Looks like a variable body but the class is not a Variable. + (root / "helper.py").write_text( + dedent("""\ + class NotAVariable: + def some_method(tax_unit, period, parameters): + return tax_unit("some_variable", period) + """) + ) + graph = extract_from_path(root) + assert not graph.has_variable("NotAVariable") + # And no edge to "some_variable" should exist from a phantom source. + assert list(graph.impact("some_variable")) == [] From dcce3ff84e2bf3361cf5f9355c1de15f83c35956 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Mon, 20 Apr 2026 10:29:23 -0400 Subject: [PATCH 2/2] Fix ruff lint and format --- docs/_generator/build_reference.py | 17 +++++++++++------ src/policyengine/graph/extractor.py | 3 +-- tests/test_graph/test_extractor.py | 2 -- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/docs/_generator/build_reference.py b/docs/_generator/build_reference.py index 4b360622..490420cd 100644 --- a/docs/_generator/build_reference.py +++ b/docs/_generator/build_reference.py @@ -85,7 +85,9 @@ class VariableRecord: tree_path: tuple[str, ...] -def _tree_path_from_source(source_file: Path | None, package_root: Path) -> tuple[str, ...]: +def _tree_path_from_source( + source_file: Path | None, package_root: Path +) -> tuple[str, ...]: if source_file is None: return ("_ungrouped",) try: @@ -131,7 +133,9 @@ def _variable_records(country: str) -> Iterable[VariableRecord]: value_type_name = ( value_type.__name__ if isinstance(value_type, type) - else str(value_type) if value_type is not None else None + else str(value_type) + if value_type is not None + else None ) defined_for = getattr(variable, "defined_for", None) defined_for_name = ( @@ -185,7 +189,10 @@ def _render_variable_page(record: VariableRecord, country: str) -> str: ("Entity", f"`{record.entity}`" if record.entity else "—"), ("Value type", f"`{record.value_type}`" if record.value_type else "—"), ("Unit", f"`{record.unit}`" if record.unit else "—"), - ("Period", f"`{record.definition_period}`" if record.definition_period else "—"), + ( + "Period", + f"`{record.definition_period}`" if record.definition_period else "—", + ), ( "Defined for", f"`{record.defined_for}`" if record.defined_for else "—", @@ -230,9 +237,7 @@ def _render_variable_page(record: VariableRecord, country: str) -> str: if record.source_file: try: - repo_rel = record.source_file.relative_to( - record.source_file.parents[5] - ) + repo_rel = record.source_file.relative_to(record.source_file.parents[5]) except (ValueError, IndexError): repo_rel = record.source_file.name lines.append("## Source") diff --git a/src/policyengine/graph/extractor.py b/src/policyengine/graph/extractor.py index 1af61a7b..39f278cb 100644 --- a/src/policyengine/graph/extractor.py +++ b/src/policyengine/graph/extractor.py @@ -34,11 +34,10 @@ import ast import os from pathlib import Path -from typing import Iterable, Iterator, Sequence, Union +from typing import Iterable, Iterator, Union from policyengine.graph.graph import VariableGraph - # Names of entity instances as they appear as method parameters in # Variable formulas. Any ``Call`` whose ``func`` is a bare ``Name`` # matching one of these AND whose first arg is a string literal is diff --git a/tests/test_graph/test_extractor.py b/tests/test_graph/test_extractor.py index 81445caf..91e2a840 100644 --- a/tests/test_graph/test_extractor.py +++ b/tests/test_graph/test_extractor.py @@ -25,8 +25,6 @@ from textwrap import dedent from types import ModuleType -import pytest - # ``policyengine/__init__.py`` eagerly imports the full country-model # stack (policyengine-us, policyengine-uk), which makes a normal