From e3c05554c481797c8be8f0d980ebbfd7414d6948 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 14 Apr 2026 22:37:36 +0200 Subject: [PATCH 1/3] Add traceability and scope fingerprinting --- modal_app/local_area.py | 137 +++++++--- modal_app/pipeline.py | 27 +- .../calibration/local_h5/__init__.py | 2 +- .../calibration/local_h5/fingerprinting.py | 247 ++++++++++++++++++ .../calibration/publish_local_area.py | 72 ++--- .../fixtures/test_local_h5_fingerprinting.py | 164 ++++++++++++ .../test_local_h5_fingerprinting.py | 80 ++++++ tests/unit/fixtures/test_modal_local_area.py | 16 ++ tests/unit/test_pipeline.py | 35 +++ 9 files changed, 696 insertions(+), 84 deletions(-) create mode 100644 policyengine_us_data/calibration/local_h5/fingerprinting.py create mode 100644 tests/unit/calibration/fixtures/test_local_h5_fingerprinting.py create mode 100644 tests/unit/calibration/test_local_h5_fingerprinting.py diff --git a/modal_app/local_area.py b/modal_app/local_area.py index 0beafee5c..e5045582d 100644 --- a/modal_app/local_area.py +++ b/modal_app/local_area.py @@ -29,6 +29,10 @@ from modal_app.images import cpu_image as image # noqa: E402 from modal_app.resilience import reconcile_run_dir_fingerprint # noqa: E402 +from policyengine_us_data.calibration.local_h5.fingerprinting import ( # noqa: E402 + FingerprintingService, + PublishingInputBundle, +) from policyengine_us_data.calibration.local_h5.partitioning import ( # noqa: E402 partition_weighted_work_items, ) @@ -311,6 +315,65 @@ def get_version() -> str: return pyproject["project"]["version"] +def _build_publishing_input_bundle( + *, + weights_path: Path, + dataset_path: Path, + db_path: Path | None, + geography_path: Path | None, + calibration_package_path: Path | None, + run_config_path: Path | None, + run_id: str, + version: str, + n_clones: int | None, + seed: int, + legacy_blocks_path: Path | None = None, +) -> PublishingInputBundle: + """Build the normalized coordinator input bundle for one publish scope.""" + + return PublishingInputBundle( + weights_path=weights_path, + source_dataset_path=dataset_path, + target_db_path=db_path, + exact_geography_path=geography_path, + calibration_package_path=calibration_package_path, + run_config_path=run_config_path, + run_id=run_id, + version=version, + n_clones=n_clones, + seed=seed, + legacy_blocks_path=legacy_blocks_path, + ) + + +def _resolve_scope_fingerprint( + *, + inputs: PublishingInputBundle, + scope: str, + expected_fingerprint: str = "", +) -> str: + """Compute the scope fingerprint while preserving pinned resume values.""" + + service = FingerprintingService() + traceability = service.build_traceability(inputs=inputs, scope=scope) + computed_fingerprint = service.compute_scope_fingerprint(traceability) + if expected_fingerprint: + if expected_fingerprint != computed_fingerprint: + print( + "WARNING: Pinned fingerprint differs from current " + f"{scope} scope fingerprint. " + "Preserving pinned value for backward-compatible resume.\n" + f" Pinned: {expected_fingerprint}\n" + f" Current: {computed_fingerprint}" + ) + else: + print( + f"Using pinned fingerprint from pipeline: {expected_fingerprint}" + ) + return expected_fingerprint + return computed_fingerprint + + def partition_work( work_items: List[Dict], num_workers: int, @@ -836,45 +899,26 @@ def coordinate_publish( validate = False # Fingerprint-based cache invalidation - if expected_fingerprint: - fingerprint = expected_fingerprint - print(f"Using pinned fingerprint from pipeline: {fingerprint}") - else: - geography_path_expr = ( - f'Path("{geography_path}")' if geography_path.exists() else "None" - ) - package_path_expr = ( - f'Path("{calibration_package_path}")' - if calibration_package_path.exists() - else "None" - ) - fp_result = subprocess.run( - _python_cmd( - "-c", - f""" -from pathlib import Path -from policyengine_us_data.calibration.publish_local_area import ( - compute_input_fingerprint, -) -print( - compute_input_fingerprint( - Path("{weights_path}"), - Path("{dataset_path}"), - {n_clones}, + fingerprint_inputs = _build_publishing_input_bundle( + weights_path=weights_path, + dataset_path=dataset_path, + db_path=db_path, + geography_path=geography_path, + calibration_package_path=( + calibration_package_path if calibration_package_path.exists() else None + ), + run_config_path=config_json_path if config_json_path.exists() else None, + run_id=run_id, + version=version, + n_clones=n_clones, seed=42, - geography_path={geography_path_expr}, - calibration_package_path={package_path_expr}, + legacy_blocks_path=artifacts / "stacked_blocks.npy", + ) + fingerprint = _resolve_scope_fingerprint( + inputs=fingerprint_inputs, + scope="regional", + expected_fingerprint=expected_fingerprint, ) -) -""", - ), - capture_output=True, - text=True, - env=os.environ.copy(), - ) - if fp_result.returncode != 0: - raise RuntimeError(f"Failed to compute fingerprint: {fp_result.stderr}") - fingerprint = fp_result.stdout.strip() reconcile_action = reconcile_run_dir_fingerprint(run_dir, fingerprint) if reconcile_action == "resume": print(f"Inputs unchanged ({fingerprint}), resuming...") @@ -1064,6 +1108,7 @@ def coordinate_national_publish( n_clones: int = 430, validate: bool = True, run_id: str = "", + expected_fingerprint: str = "", ) -> Dict: """Build and upload a national US.h5 from national weights.""" setup_gcp_credentials() @@ -1123,6 +1168,23 @@ def coordinate_national_publish( "geography_assignment.npz": "national_geography_assignment.npz", }, ) + fingerprint_inputs = _build_publishing_input_bundle( + weights_path=weights_path, + dataset_path=dataset_path, + db_path=db_path, + geography_path=geography_path, + calibration_package_path=None, + run_config_path=config_json_path if config_json_path.exists() else None, + run_id=run_id, + version=version, + n_clones=n_clones, + seed=42, + ) + fingerprint = _resolve_scope_fingerprint( + inputs=fingerprint_inputs, + scope="national", + expected_fingerprint=expected_fingerprint, + ) run_dir = staging_dir / run_id run_dir.mkdir(parents=True, exist_ok=True) @@ -1224,6 +1286,7 @@ def coordinate_national_publish( f"{version}. Run main_national_promote to publish." ), "run_id": run_id, + "fingerprint": fingerprint, "national_validation": national_validation_output, } diff --git a/modal_app/pipeline.py b/modal_app/pipeline.py index c02d6f10e..02b79d983 100644 --- a/modal_app/pipeline.py +++ b/modal_app/pipeline.py @@ -109,12 +109,26 @@ class RunMetadata: error: Optional[str] = None resume_history: list = field(default_factory=list) fingerprint: Optional[str] = None + regional_fingerprint: Optional[str] = None + national_fingerprint: Optional[str] = None + + def __post_init__(self) -> None: + if self.regional_fingerprint is None and self.fingerprint is not None: + self.regional_fingerprint = self.fingerprint + if self.fingerprint is None and self.regional_fingerprint is not None: + self.fingerprint = self.regional_fingerprint def to_dict(self) -> dict: - return asdict(self) + data = asdict(self) + if data.get("fingerprint") is None and data.get("regional_fingerprint") is not None: + data["fingerprint"] = data["regional_fingerprint"] + return data @classmethod def from_dict(cls, data: dict) -> "RunMetadata": + data = dict(data) + if data.get("regional_fingerprint") is None and data.get("fingerprint") is not None: + data["regional_fingerprint"] = data["fingerprint"] return cls(**data) @@ -981,7 +995,9 @@ def run_pipeline( n_clones=n_clones, validate=True, run_id=run_id, - expected_fingerprint=meta.fingerprint or "", + expected_fingerprint=( + meta.regional_fingerprint or meta.fingerprint or "" + ), ) print(f" → coordinate_publish fc: {regional_h5_handle.object_id}") @@ -993,6 +1009,7 @@ def run_pipeline( n_clones=n_clones, validate=True, run_id=run_id, + expected_fingerprint=meta.national_fingerprint or "", ) print( f" → coordinate_national_publish fc: {national_h5_handle.object_id}" @@ -1017,6 +1034,7 @@ def run_pipeline( if isinstance(regional_h5_result, dict) and regional_h5_result.get( "fingerprint" ): + meta.regional_fingerprint = regional_h5_result["fingerprint"] meta.fingerprint = regional_h5_result["fingerprint"] write_run_meta(meta, pipeline_volume) @@ -1030,6 +1048,11 @@ def run_pipeline( else national_h5_result ) print(f" National H5: {national_msg}") + if isinstance(national_h5_result, dict) and national_h5_result.get( + "fingerprint" + ): + meta.national_fingerprint = national_h5_result["fingerprint"] + write_run_meta(meta, pipeline_volume) # ── Aggregate validation results ── _write_validation_diagnostics( diff --git a/policyengine_us_data/calibration/local_h5/__init__.py b/policyengine_us_data/calibration/local_h5/__init__.py index f69663eb0..96ec7258f 100644 --- a/policyengine_us_data/calibration/local_h5/__init__.py +++ b/policyengine_us_data/calibration/local_h5/__init__.py @@ -3,5 +3,5 @@ Modules in this package should land only when they become active runtime seams rather than speculative placeholders. The current early slices introduce ``partitioning.py``, ``requests.py``, ``area_catalog.py``, -and ``geography_loader.py``. +``fingerprinting.py``, and ``geography_loader.py``. """ diff --git a/policyengine_us_data/calibration/local_h5/fingerprinting.py b/policyengine_us_data/calibration/local_h5/fingerprinting.py new file mode 100644 index 000000000..6bff37af0 --- /dev/null +++ b/policyengine_us_data/calibration/local_h5/fingerprinting.py @@ -0,0 +1,247 @@ +"""Coordinator-owned provenance and resumability logic for local H5 publication.""" + +from __future__ import annotations + +import hashlib +import json +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Literal, Mapping + +from .geography_loader import CalibrationGeographyLoader + +FingerprintScope = Literal["regional", "national"] + + +@dataclass(frozen=True) +class PublishingInputBundle: + """File-system and run metadata needed to publish one H5 scope.""" + + weights_path: Path + source_dataset_path: Path + target_db_path: Path | None + exact_geography_path: Path | None + calibration_package_path: Path | None + run_config_path: Path | None + run_id: str + version: str + n_clones: int | None + seed: int + legacy_blocks_path: Path | None = None + + +@dataclass(frozen=True) +class ArtifactIdentity: + """Stable identity for one input artifact used by traceability and resume.""" + + logical_name: str + path: Path | None + sha256: str | None + size_bytes: int | None = None + metadata: Mapping[str, Any] = field(default_factory=dict) + + +@dataclass(frozen=True) +class TraceabilityBundle: + """Full provenance record for one publish scope.""" + + scope: FingerprintScope + weights: ArtifactIdentity + source_dataset: ArtifactIdentity + exact_geography: ArtifactIdentity | None = None + target_db: ArtifactIdentity | None = None + calibration_package: ArtifactIdentity | None = None + run_config: ArtifactIdentity | None = None + code_version: Mapping[str, Any] = field(default_factory=dict) + model_build: Mapping[str, Any] = field(default_factory=dict) + metadata: Mapping[str, Any] = field(default_factory=dict) + + def resumability_material(self) -> Mapping[str, Any]: + """Return the normalized subset that controls staged-output validity.""" + + geography_sha = None + if self.exact_geography is not None: + geography_sha = self.exact_geography.metadata.get("canonical_sha256") + if geography_sha is None: + geography_sha = self.exact_geography.sha256 + + return { + "scope": self.scope, + "weights_sha256": self.weights.sha256, + "source_dataset_sha256": self.source_dataset.sha256, + "exact_geography_sha256": geography_sha, + "target_db_sha256": ( + self.target_db.sha256 if self.target_db is not None else None + ), + "n_clones": self.metadata.get("n_clones"), + "seed": self.metadata.get("seed"), + "policyengine_us_locked_version": self.model_build.get("locked_version"), + "policyengine_us_git_commit": self.model_build.get("git_commit"), + } + + +class FingerprintingService: + """Build traceability bundles and derive scope fingerprints from them.""" + + def __init__( + self, + *, + geography_loader: CalibrationGeographyLoader | None = None, + ) -> None: + self._geography_loader = geography_loader or CalibrationGeographyLoader() + + def build_traceability( + self, + *, + inputs: PublishingInputBundle, + scope: FingerprintScope, + ) -> TraceabilityBundle: + """Build a traceability bundle from current publish inputs.""" + + run_config_payload = self._load_json(inputs.run_config_path) + return TraceabilityBundle( + scope=scope, + weights=self._build_artifact_identity("weights", inputs.weights_path), + source_dataset=self._build_artifact_identity( + "source_dataset", + inputs.source_dataset_path, + ), + exact_geography=self._build_geography_identity(inputs), + target_db=self._build_optional_artifact_identity( + "target_db", + inputs.target_db_path, + ), + calibration_package=self._build_optional_artifact_identity( + "calibration_package", + inputs.calibration_package_path, + ), + run_config=self._build_optional_artifact_identity( + "run_config", + inputs.run_config_path, + ), + code_version=self._extract_code_version(run_config_payload), + model_build=self._extract_model_build(run_config_payload), + metadata={ + "run_id": inputs.run_id, + "version": inputs.version, + "n_clones": inputs.n_clones, + "seed": inputs.seed, + }, + ) + + def compute_scope_fingerprint(self, traceability: TraceabilityBundle) -> str: + """Hash normalized resumability material into a short scope fingerprint.""" + + payload = json.dumps( + traceability.resumability_material(), + sort_keys=True, + separators=(",", ":"), + ).encode() + return hashlib.sha256(payload).hexdigest()[:16] + + def _build_artifact_identity( + self, + logical_name: str, + path: Path, + *, + metadata: Mapping[str, Any] | None = None, + ) -> ArtifactIdentity: + actual_path = Path(path) + if not actual_path.exists(): + raise FileNotFoundError(f"Expected {logical_name} artifact at {actual_path}") + return ArtifactIdentity( + logical_name=logical_name, + path=actual_path, + sha256=self._sha256_file(actual_path), + size_bytes=actual_path.stat().st_size, + metadata=dict(metadata or {}), + ) + + def _build_optional_artifact_identity( + self, + logical_name: str, + path: Path | None, + ) -> ArtifactIdentity | None: + if path is None: + return None + actual_path = Path(path) + if not actual_path.exists(): + return None + return self._build_artifact_identity(logical_name, actual_path) + + def _build_geography_identity( + self, + inputs: PublishingInputBundle, + ) -> ArtifactIdentity | None: + resolved = self._geography_loader.resolve_source( + weights_path=inputs.weights_path, + geography_path=inputs.exact_geography_path, + blocks_path=inputs.legacy_blocks_path, + calibration_package_path=inputs.calibration_package_path, + ) + if resolved is None: + return None + + metadata = { + "source_kind": resolved.kind, + "canonical_sha256": self._geography_loader.compute_canonical_checksum( + weights_path=inputs.weights_path, + n_records=self._infer_n_records(inputs.source_dataset_path), + n_clones=inputs.n_clones, + geography_path=inputs.exact_geography_path, + blocks_path=inputs.legacy_blocks_path, + calibration_package_path=inputs.calibration_package_path, + ), + } + return self._build_artifact_identity( + "exact_geography", + resolved.path, + metadata=metadata, + ) + + def _extract_code_version(self, run_config_payload: Mapping[str, Any]) -> dict[str, Any]: + return { + "git_commit": run_config_payload.get("git_commit"), + "git_branch": run_config_payload.get("git_branch"), + "git_dirty": run_config_payload.get("git_dirty"), + } + + def _extract_model_build(self, run_config_payload: Mapping[str, Any]) -> dict[str, Any]: + return { + "locked_version": run_config_payload.get("package_version"), + "git_commit": run_config_payload.get("git_commit"), + } + + def _load_json(self, path: Path | None) -> Mapping[str, Any]: + if path is None: + return {} + actual_path = Path(path) + if not actual_path.exists(): + return {} + with open(actual_path) as handle: + return json.load(handle) + + def _sha256_file(self, path: Path) -> str: + digest = hashlib.sha256() + with open(path, "rb") as handle: + for chunk in iter(lambda: handle.read(1 << 20), b""): + digest.update(chunk) + return f"sha256:{digest.hexdigest()}" + + def _infer_n_records(self, source_dataset_path: Path) -> int: + import h5py + + with h5py.File(source_dataset_path, "r") as handle: + if "person" not in handle: + raise ValueError( + f"Unable to infer n_records from {source_dataset_path}: " + "missing 'person' entity" + ) + person_group = handle["person"] + first_dataset_name = next(iter(person_group.keys()), None) + if first_dataset_name is None: + raise ValueError( + f"Unable to infer n_records from {source_dataset_path}: " + "'person' entity is empty" + ) + return int(len(person_group[first_dataset_name])) diff --git a/policyengine_us_data/calibration/publish_local_area.py b/policyengine_us_data/calibration/publish_local_area.py index b1946a8f3..785fbafc8 100644 --- a/policyengine_us_data/calibration/publish_local_area.py +++ b/policyengine_us_data/calibration/publish_local_area.py @@ -11,12 +11,15 @@ import json import shutil - import numpy as np from pathlib import Path from typing import List, Optional from policyengine_us import Microsimulation +from policyengine_us_data.calibration.local_h5.fingerprinting import ( + FingerprintingService, + PublishingInputBundle, +) from policyengine_us_data.calibration.local_h5.geography_loader import ( CalibrationGeographyLoader, ) @@ -48,8 +51,6 @@ META_FILE = WORK_DIR / "checkpoint_meta.json" - - def compute_input_fingerprint( weights_path: Path, dataset_path: Path, @@ -57,50 +58,33 @@ def compute_input_fingerprint( seed: int = 42, geography_path: Optional[Path] = None, blocks_path: Optional[Path] = None, + target_db_path: Optional[Path] = None, + run_config_path: Optional[Path] = None, calibration_package_path: Optional[Path] = None, + scope: str = "regional", ) -> str: - import hashlib - - def _update_hash_from_file(h: "hashlib._Hash", path: Path) -> None: - with open(path, "rb") as f: - while chunk := f.read(8192): - h.update(chunk) - - def _infer_n_records() -> int: - if n_clones is not None: - weights = np.load(weights_path, mmap_mode="r") - if len(weights) % n_clones == 0: - return len(weights) // n_clones - sim = Microsimulation(dataset=str(dataset_path)) - return len(sim.calculate("household_id", map_to="household").values) - - loader = CalibrationGeographyLoader() - h = hashlib.sha256() - for p in [weights_path, dataset_path]: - _update_hash_from_file(h, p) - - resolved = loader.resolve_source( - weights_path=weights_path, - geography_path=geography_path, - blocks_path=blocks_path, - calibration_package_path=calibration_package_path, + service = FingerprintingService() + inputs = PublishingInputBundle( + weights_path=Path(weights_path), + source_dataset_path=Path(dataset_path), + target_db_path=Path(target_db_path) if target_db_path is not None else None, + exact_geography_path=( + Path(geography_path) if geography_path is not None else None + ), + calibration_package_path=( + Path(calibration_package_path) + if calibration_package_path is not None + else None + ), + run_config_path=Path(run_config_path) if run_config_path is not None else None, + run_id="", + version="", + n_clones=n_clones, + seed=seed, + legacy_blocks_path=Path(blocks_path) if blocks_path is not None else None, ) - if resolved is not None: - n_records = _infer_n_records() - h.update(f"geography_source:{resolved.kind}".encode()) - h.update( - loader.compute_canonical_checksum( - weights_path=weights_path, - n_records=n_records, - n_clones=n_clones, - geography_path=geography_path, - blocks_path=blocks_path, - calibration_package_path=calibration_package_path, - ).encode() - ) - else: - h.update(f"legacy_regeneration:{n_clones}:{seed}".encode()) - return h.hexdigest()[:16] + traceability = service.build_traceability(inputs=inputs, scope=scope) + return service.compute_scope_fingerprint(traceability) def load_calibration_geography( diff --git a/tests/unit/calibration/fixtures/test_local_h5_fingerprinting.py b/tests/unit/calibration/fixtures/test_local_h5_fingerprinting.py new file mode 100644 index 000000000..d8bba2148 --- /dev/null +++ b/tests/unit/calibration/fixtures/test_local_h5_fingerprinting.py @@ -0,0 +1,164 @@ +"""Fixture helpers for ``test_local_h5_fingerprinting.py``.""" + +from __future__ import annotations + +import importlib.util +import json +import sys +from pathlib import Path +from types import ModuleType + +import h5py +import numpy as np + +from tests.unit.calibration.fixtures.test_local_h5_geography_loader import ( + write_saved_geography, +) + +__test__ = False + + +def _ensure_package(name: str, path: Path) -> None: + """Register a synthetic package so relative imports resolve locally.""" + + package = sys.modules.get(name) + if package is None: + package = ModuleType(name) + package.__path__ = [str(path)] + sys.modules[name] = package + return + package.__path__ = [str(path)] + + +def _load_module(name: str, path: Path): + """Load one module from disk under a specific fully-qualified name.""" + + sys.modules.pop(name, None) + spec = importlib.util.spec_from_file_location(name, path) + module = importlib.util.module_from_spec(spec) + assert spec is not None + assert spec.loader is not None + sys.modules[name] = module + spec.loader.exec_module(module) + return module + + +def load_fingerprinting_exports(): + """Load the local H5 fingerprinting module under a synthetic package name.""" + + repo_root = Path(__file__).resolve().parents[4] + local_h5_root = ( + repo_root + / "policyengine_us_data" + / "calibration" + / "local_h5" + ) + calibration_root = repo_root / "policyengine_us_data" / "calibration" + storage_root = repo_root / "policyengine_us_data" / "storage" + package_name = "local_h5_fingerprinting_fixture" + policyengine_package = "policyengine_us_data" + calibration_package = "policyengine_us_data.calibration" + + for name in list(sys.modules): + if ( + name == package_name + or name.startswith(f"{package_name}.") + or name == policyengine_package + or name.startswith(f"{policyengine_package}.") + ): + sys.modules.pop(name, None) + + _ensure_package(package_name, local_h5_root) + _ensure_package(policyengine_package, repo_root / "policyengine_us_data") + _ensure_package(calibration_package, calibration_root) + _load_module( + "policyengine_us_data.storage", + storage_root / "__init__.py", + ) + _load_module( + "policyengine_us_data.calibration.clone_and_assign", + calibration_root / "clone_and_assign.py", + ) + _load_module( + f"{package_name}.geography_loader", + local_h5_root / "geography_loader.py", + ) + module = _load_module( + f"{package_name}.fingerprinting", + local_h5_root / "fingerprinting.py", + ) + return { + "module": module, + "ArtifactIdentity": module.ArtifactIdentity, + "FingerprintingService": module.FingerprintingService, + "PublishingInputBundle": module.PublishingInputBundle, + "TraceabilityBundle": module.TraceabilityBundle, + } + + +def write_source_dataset(path: Path, *, n_records: int) -> None: + """Write a minimal HDF5 dataset with a ``person`` entity.""" + + with h5py.File(path, "w") as handle: + person = handle.create_group("person") + person.create_dataset("person_id", data=np.arange(n_records, dtype=np.int32)) + + +def write_run_config(path: Path, *, package_version: str = "1.0.0") -> None: + """Write a minimal run-config payload with provenance fields.""" + + payload = { + "git_commit": "deadbeefcafebabe", + "git_branch": "main", + "git_dirty": False, + "package_version": package_version, + } + path.write_text(json.dumps(payload)) + + +def write_artifact_file(path: Path, content: bytes) -> None: + """Write one small binary artifact for traceability tests.""" + + path.write_bytes(content) + + +def make_publishing_inputs( + bundle_cls, + *, + tmp_path: Path, + n_records: int = 2, + n_clones: int = 2, + seed: int = 42, + package_version: str = "1.0.0", +): + """Create a fully-populated publishing input bundle for tests.""" + + tmp_path.mkdir(parents=True, exist_ok=True) + weights_path = tmp_path / "calibration_weights.npy" + dataset_path = tmp_path / "source.h5" + db_path = tmp_path / "policy_data.db" + geography_path = tmp_path / "geography_assignment.npz" + run_config_path = tmp_path / "unified_run_config.json" + + np.save(weights_path, np.array([1.0, 2.0, 3.0])) + write_source_dataset(dataset_path, n_records=n_records) + write_artifact_file(db_path, b"fake-db") + write_saved_geography( + geography_path, + n_records=n_records, + n_clones=n_clones, + ) + write_run_config(run_config_path, package_version=package_version) + + return bundle_cls( + weights_path=weights_path, + source_dataset_path=dataset_path, + target_db_path=db_path, + exact_geography_path=geography_path, + calibration_package_path=None, + run_config_path=run_config_path, + run_id="run-123", + version="1.2.3", + n_clones=n_clones, + seed=seed, + ) diff --git a/tests/unit/calibration/test_local_h5_fingerprinting.py b/tests/unit/calibration/test_local_h5_fingerprinting.py new file mode 100644 index 000000000..180e242f1 --- /dev/null +++ b/tests/unit/calibration/test_local_h5_fingerprinting.py @@ -0,0 +1,80 @@ +from tests.unit.calibration.fixtures.test_local_h5_fingerprinting import ( + load_fingerprinting_exports, + make_publishing_inputs, +) + + +exports = load_fingerprinting_exports() +FingerprintingService = exports["FingerprintingService"] +PublishingInputBundle = exports["PublishingInputBundle"] + + +def test_build_traceability_captures_artifact_identity_and_metadata(tmp_path): + inputs = make_publishing_inputs(PublishingInputBundle, tmp_path=tmp_path) + + service = FingerprintingService() + traceability = service.build_traceability(inputs=inputs, scope="regional") + + assert traceability.scope == "regional" + assert traceability.weights.path == inputs.weights_path + assert traceability.weights.sha256.startswith("sha256:") + assert traceability.source_dataset.sha256.startswith("sha256:") + assert traceability.exact_geography is not None + assert traceability.exact_geography.metadata["source_kind"] == "saved_geography" + assert traceability.exact_geography.metadata["canonical_sha256"].startswith( + "sha256:" + ) + assert traceability.target_db is not None + assert traceability.model_build["locked_version"] == "1.0.0" + assert traceability.metadata["n_clones"] == 2 + assert traceability.metadata["seed"] == 42 + + +def test_scope_fingerprint_differs_between_regional_and_national(tmp_path): + inputs = make_publishing_inputs(PublishingInputBundle, tmp_path=tmp_path) + + service = FingerprintingService() + regional = service.compute_scope_fingerprint( + service.build_traceability(inputs=inputs, scope="regional") + ) + national = service.compute_scope_fingerprint( + service.build_traceability(inputs=inputs, scope="national") + ) + + assert regional != national + + +def test_scope_fingerprint_is_stable_for_identical_inputs(tmp_path): + inputs = make_publishing_inputs(PublishingInputBundle, tmp_path=tmp_path) + + service = FingerprintingService() + first = service.compute_scope_fingerprint( + service.build_traceability(inputs=inputs, scope="regional") + ) + second = service.compute_scope_fingerprint( + service.build_traceability(inputs=inputs, scope="regional") + ) + + assert first == second + + +def test_scope_fingerprint_changes_when_relevant_provenance_changes(tmp_path): + first_inputs = make_publishing_inputs( + PublishingInputBundle, + tmp_path=tmp_path / "first", + ) + second_inputs = make_publishing_inputs( + PublishingInputBundle, + tmp_path=tmp_path / "second", + ) + second_inputs.target_db_path.write_bytes(b"changed-db") + + service = FingerprintingService() + first = service.compute_scope_fingerprint( + service.build_traceability(inputs=first_inputs, scope="regional") + ) + second = service.compute_scope_fingerprint( + service.build_traceability(inputs=second_inputs, scope="regional") + ) + + assert first != second diff --git a/tests/unit/fixtures/test_modal_local_area.py b/tests/unit/fixtures/test_modal_local_area.py index 377e879ae..935da8d6e 100644 --- a/tests/unit/fixtures/test_modal_local_area.py +++ b/tests/unit/fixtures/test_modal_local_area.py @@ -41,6 +41,9 @@ def load_local_area_module(): fake_partitioning = ModuleType( "policyengine_us_data.calibration.local_h5.partitioning" ) + fake_fingerprinting = ModuleType( + "policyengine_us_data.calibration.local_h5.fingerprinting" + ) fake_policyengine.__path__ = [] fake_calibration.__path__ = [] fake_local_h5.__path__ = [] @@ -71,6 +74,16 @@ def decorator(func): fake_resilience = ModuleType("modal_app.resilience") fake_resilience.reconcile_run_dir_fingerprint = lambda *args, **kwargs: None fake_partitioning.partition_weighted_work_items = lambda *args, **kwargs: [] + fake_fingerprinting.PublishingInputBundle = object + + class _FakeFingerprintingService: + def build_traceability(self, *args, **kwargs): + return object() + + def compute_scope_fingerprint(self, *args, **kwargs): + return "fake-fingerprint" + + fake_fingerprinting.FingerprintingService = _FakeFingerprintingService with _patched_module_registry( { @@ -80,6 +93,9 @@ def decorator(func): "policyengine_us_data": fake_policyengine, "policyengine_us_data.calibration": fake_calibration, "policyengine_us_data.calibration.local_h5": fake_local_h5, + "policyengine_us_data.calibration.local_h5.fingerprinting": ( + fake_fingerprinting + ), "policyengine_us_data.calibration.local_h5.partitioning": ( fake_partitioning ), diff --git a/tests/unit/test_pipeline.py b/tests/unit/test_pipeline.py index 5aaca8a47..ed46c16f9 100644 --- a/tests/unit/test_pipeline.py +++ b/tests/unit/test_pipeline.py @@ -63,6 +63,23 @@ def test_from_dict(self): assert meta.status == "completed" assert meta.step_timings["build_datasets"]["status"] == "completed" + def test_from_dict_maps_legacy_fingerprint_to_regional_scope(self): + meta = RunMetadata.from_dict( + { + "run_id": "test", + "branch": "main", + "sha": "abc12345deadbeef", + "version": "1.72.3", + "start_time": "2026-03-19T12:00:00Z", + "status": "running", + "fingerprint": "legacy-fingerprint", + } + ) + + assert meta.fingerprint == "legacy-fingerprint" + assert meta.regional_fingerprint == "legacy-fingerprint" + assert meta.national_fingerprint is None + def test_roundtrip(self): meta = RunMetadata( run_id="1.72.3_abc12345_20260319_120000", @@ -79,6 +96,24 @@ def test_roundtrip(self): assert roundtripped.status == meta.status assert roundtripped.error == meta.error + def test_to_dict_keeps_legacy_fingerprint_alias_in_sync(self): + meta = RunMetadata( + run_id="test", + branch="main", + sha="abc", + version="1.0.0", + start_time="now", + status="running", + regional_fingerprint="regional-fp", + national_fingerprint="national-fp", + ) + + payload = meta.to_dict() + + assert payload["fingerprint"] == "regional-fp" + assert payload["regional_fingerprint"] == "regional-fp" + assert payload["national_fingerprint"] == "national-fp" + def test_step_timings_default_empty(self): meta = RunMetadata( run_id="test", From a1e441284d94111e40b1f501e1039496c2f52818 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 16 Apr 2026 01:19:44 +0200 Subject: [PATCH 2/3] Tighten local H5 fingerprinting boundary --- modal_app/local_area.py | 2 - modal_app/pipeline.py | 7 -- .../calibration/local_h5/fingerprinting.py | 43 ++++---- .../fixtures/test_local_h5_fingerprinting.py | 99 +++++-------------- .../test_local_h5_fingerprinting.py | 18 ++++ tests/unit/test_pipeline.py | 3 - 6 files changed, 69 insertions(+), 103 deletions(-) diff --git a/modal_app/local_area.py b/modal_app/local_area.py index e5045582d..97499f08b 100644 --- a/modal_app/local_area.py +++ b/modal_app/local_area.py @@ -1108,7 +1108,6 @@ def coordinate_national_publish( n_clones: int = 430, validate: bool = True, run_id: str = "", - expected_fingerprint: str = "", ) -> Dict: """Build and upload a national US.h5 from national weights.""" setup_gcp_credentials() @@ -1183,7 +1182,6 @@ def coordinate_national_publish( fingerprint = _resolve_scope_fingerprint( inputs=fingerprint_inputs, scope="national", - expected_fingerprint=expected_fingerprint, ) run_dir = staging_dir / run_id run_dir.mkdir(parents=True, exist_ok=True) diff --git a/modal_app/pipeline.py b/modal_app/pipeline.py index 02b79d983..5aa89f8fd 100644 --- a/modal_app/pipeline.py +++ b/modal_app/pipeline.py @@ -110,7 +110,6 @@ class RunMetadata: resume_history: list = field(default_factory=list) fingerprint: Optional[str] = None regional_fingerprint: Optional[str] = None - national_fingerprint: Optional[str] = None def __post_init__(self) -> None: if self.regional_fingerprint is None and self.fingerprint is not None: @@ -1009,7 +1008,6 @@ def run_pipeline( n_clones=n_clones, validate=True, run_id=run_id, - expected_fingerprint=meta.national_fingerprint or "", ) print( f" → coordinate_national_publish fc: {national_h5_handle.object_id}" @@ -1048,11 +1046,6 @@ def run_pipeline( else national_h5_result ) print(f" National H5: {national_msg}") - if isinstance(national_h5_result, dict) and national_h5_result.get( - "fingerprint" - ): - meta.national_fingerprint = national_h5_result["fingerprint"] - write_run_meta(meta, pipeline_volume) # ── Aggregate validation results ── _write_validation_diagnostics( diff --git a/policyengine_us_data/calibration/local_h5/fingerprinting.py b/policyengine_us_data/calibration/local_h5/fingerprinting.py index 6bff37af0..8f401e582 100644 --- a/policyengine_us_data/calibration/local_h5/fingerprinting.py +++ b/policyengine_us_data/calibration/local_h5/fingerprinting.py @@ -186,7 +186,11 @@ def _build_geography_identity( "source_kind": resolved.kind, "canonical_sha256": self._geography_loader.compute_canonical_checksum( weights_path=inputs.weights_path, - n_records=self._infer_n_records(inputs.source_dataset_path), + n_records=self._infer_n_records( + weights_path=inputs.weights_path, + source_dataset_path=inputs.source_dataset_path, + n_clones=inputs.n_clones, + ), n_clones=inputs.n_clones, geography_path=inputs.exact_geography_path, blocks_path=inputs.legacy_blocks_path, @@ -228,20 +232,23 @@ def _sha256_file(self, path: Path) -> str: digest.update(chunk) return f"sha256:{digest.hexdigest()}" - def _infer_n_records(self, source_dataset_path: Path) -> int: - import h5py - - with h5py.File(source_dataset_path, "r") as handle: - if "person" not in handle: - raise ValueError( - f"Unable to infer n_records from {source_dataset_path}: " - "missing 'person' entity" - ) - person_group = handle["person"] - first_dataset_name = next(iter(person_group.keys()), None) - if first_dataset_name is None: - raise ValueError( - f"Unable to infer n_records from {source_dataset_path}: " - "'person' entity is empty" - ) - return int(len(person_group[first_dataset_name])) + def _infer_n_records( + self, + *, + weights_path: Path, + source_dataset_path: Path, + n_clones: int | None, + ) -> int: + if n_clones is not None: + import numpy as np + + weights = np.load(weights_path, mmap_mode="r") + if len(weights) % n_clones == 0: + return int(len(weights) // n_clones) + + from policyengine_us import Microsimulation + + simulation = Microsimulation(dataset=str(source_dataset_path)) + return int( + len(simulation.calculate("household_id", map_to="household").values) + ) diff --git a/tests/unit/calibration/fixtures/test_local_h5_fingerprinting.py b/tests/unit/calibration/fixtures/test_local_h5_fingerprinting.py index d8bba2148..2ecffd000 100644 --- a/tests/unit/calibration/fixtures/test_local_h5_fingerprinting.py +++ b/tests/unit/calibration/fixtures/test_local_h5_fingerprinting.py @@ -2,11 +2,9 @@ from __future__ import annotations -import importlib.util +import importlib import json -import sys from pathlib import Path -from types import ModuleType import h5py import numpy as np @@ -17,91 +15,41 @@ __test__ = False +_FINGERPRINTING_EXPORTS = None -def _ensure_package(name: str, path: Path) -> None: - """Register a synthetic package so relative imports resolve locally.""" - package = sys.modules.get(name) - if package is None: - package = ModuleType(name) - package.__path__ = [str(path)] - sys.modules[name] = package - return - package.__path__ = [str(path)] - - -def _load_module(name: str, path: Path): - """Load one module from disk under a specific fully-qualified name.""" - - sys.modules.pop(name, None) - spec = importlib.util.spec_from_file_location(name, path) - module = importlib.util.module_from_spec(spec) - assert spec is not None - assert spec.loader is not None - sys.modules[name] = module - spec.loader.exec_module(module) - return module +def load_fingerprinting_exports(): + """Load the fingerprinting module without replacing shared package modules.""" + global _FINGERPRINTING_EXPORTS + if _FINGERPRINTING_EXPORTS is not None: + return _FINGERPRINTING_EXPORTS -def load_fingerprinting_exports(): - """Load the local H5 fingerprinting module under a synthetic package name.""" - - repo_root = Path(__file__).resolve().parents[4] - local_h5_root = ( - repo_root - / "policyengine_us_data" - / "calibration" - / "local_h5" - ) - calibration_root = repo_root / "policyengine_us_data" / "calibration" - storage_root = repo_root / "policyengine_us_data" / "storage" - package_name = "local_h5_fingerprinting_fixture" - policyengine_package = "policyengine_us_data" - calibration_package = "policyengine_us_data.calibration" - - for name in list(sys.modules): - if ( - name == package_name - or name.startswith(f"{package_name}.") - or name == policyengine_package - or name.startswith(f"{policyengine_package}.") - ): - sys.modules.pop(name, None) - - _ensure_package(package_name, local_h5_root) - _ensure_package(policyengine_package, repo_root / "policyengine_us_data") - _ensure_package(calibration_package, calibration_root) - _load_module( - "policyengine_us_data.storage", - storage_root / "__init__.py", + module = importlib.import_module( + "policyengine_us_data.calibration.local_h5.fingerprinting" ) - _load_module( - "policyengine_us_data.calibration.clone_and_assign", - calibration_root / "clone_and_assign.py", - ) - _load_module( - f"{package_name}.geography_loader", - local_h5_root / "geography_loader.py", - ) - module = _load_module( - f"{package_name}.fingerprinting", - local_h5_root / "fingerprinting.py", - ) - return { + _FINGERPRINTING_EXPORTS = { "module": module, "ArtifactIdentity": module.ArtifactIdentity, "FingerprintingService": module.FingerprintingService, "PublishingInputBundle": module.PublishingInputBundle, "TraceabilityBundle": module.TraceabilityBundle, } + return _FINGERPRINTING_EXPORTS -def write_source_dataset(path: Path, *, n_records: int) -> None: +def write_source_dataset( + path: Path, + *, + n_records: int, + person_records: int | None = None, +) -> None: """Write a minimal HDF5 dataset with a ``person`` entity.""" + person_count = person_records if person_records is not None else n_records with h5py.File(path, "w") as handle: person = handle.create_group("person") - person.create_dataset("person_id", data=np.arange(n_records, dtype=np.int32)) + person.create_dataset("person_id", data=np.arange(person_count, dtype=np.int32)) def write_run_config(path: Path, *, package_version: str = "1.0.0") -> None: @@ -127,6 +75,7 @@ def make_publishing_inputs( *, tmp_path: Path, n_records: int = 2, + person_records: int | None = None, n_clones: int = 2, seed: int = 42, package_version: str = "1.0.0", @@ -140,8 +89,12 @@ def make_publishing_inputs( geography_path = tmp_path / "geography_assignment.npz" run_config_path = tmp_path / "unified_run_config.json" - np.save(weights_path, np.array([1.0, 2.0, 3.0])) - write_source_dataset(dataset_path, n_records=n_records) + np.save(weights_path, np.arange(n_records * n_clones, dtype=float) + 1.0) + write_source_dataset( + dataset_path, + n_records=n_records, + person_records=person_records, + ) write_artifact_file(db_path, b"fake-db") write_saved_geography( geography_path, diff --git a/tests/unit/calibration/test_local_h5_fingerprinting.py b/tests/unit/calibration/test_local_h5_fingerprinting.py index 180e242f1..66f288738 100644 --- a/tests/unit/calibration/test_local_h5_fingerprinting.py +++ b/tests/unit/calibration/test_local_h5_fingerprinting.py @@ -78,3 +78,21 @@ def test_scope_fingerprint_changes_when_relevant_provenance_changes(tmp_path): ) assert first != second + + +def test_traceability_uses_weight_derived_household_count_for_geography(tmp_path): + inputs = make_publishing_inputs( + PublishingInputBundle, + tmp_path=tmp_path, + n_records=2, + person_records=5, + n_clones=2, + ) + + service = FingerprintingService() + traceability = service.build_traceability(inputs=inputs, scope="regional") + + assert traceability.exact_geography is not None + assert traceability.exact_geography.metadata["canonical_sha256"].startswith( + "sha256:" + ) diff --git a/tests/unit/test_pipeline.py b/tests/unit/test_pipeline.py index ed46c16f9..d9b31714e 100644 --- a/tests/unit/test_pipeline.py +++ b/tests/unit/test_pipeline.py @@ -78,7 +78,6 @@ def test_from_dict_maps_legacy_fingerprint_to_regional_scope(self): assert meta.fingerprint == "legacy-fingerprint" assert meta.regional_fingerprint == "legacy-fingerprint" - assert meta.national_fingerprint is None def test_roundtrip(self): meta = RunMetadata( @@ -105,14 +104,12 @@ def test_to_dict_keeps_legacy_fingerprint_alias_in_sync(self): start_time="now", status="running", regional_fingerprint="regional-fp", - national_fingerprint="national-fp", ) payload = meta.to_dict() assert payload["fingerprint"] == "regional-fp" assert payload["regional_fingerprint"] == "regional-fp" - assert payload["national_fingerprint"] == "national-fp" def test_step_timings_default_empty(self): meta = RunMetadata( From 606da074c36bf3879ef1342121c51d4b5592a590 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Sat, 25 Apr 2026 01:02:05 +0200 Subject: [PATCH 3/3] Add local H5 traceability tests --- .github/workflows/pr.yaml | 14 ++ tests/integration/local_h5/fixtures.py | 203 ++++++++++++++++++ .../test_modal_local_area_traceability.py | 64 ++++++ .../local_h5/test_traceability_contract.py | 88 ++++++++ .../test_worker_script_tiny_fixture.py | 115 ++++++++++ .../test_local_h5_fingerprinting.py | 51 +++++ tests/unit/fixtures/test_modal_local_area.py | 88 ++++---- tests/unit/test_modal_local_area.py | 163 ++++++++++++++ tests/unit/test_pipeline.py | 36 +++- 9 files changed, 780 insertions(+), 42 deletions(-) create mode 100644 tests/integration/local_h5/fixtures.py create mode 100644 tests/integration/local_h5/test_modal_local_area_traceability.py create mode 100644 tests/integration/local_h5/test_traceability_contract.py create mode 100644 tests/integration/local_h5/test_worker_script_tiny_fixture.py diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index 797ca19b0..fccc63078 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -86,6 +86,19 @@ jobs: env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + local-h5-integration-tests: + runs-on: ubuntu-latest + needs: [check-fork, lint, unit-tests] + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.14" + - uses: astral-sh/setup-uv@v5 + - run: uv sync --dev + - name: Run local H5 integration tests + run: uv run pytest --noconftest tests/integration/local_h5/ -v + optimized-integration-tests: runs-on: ubuntu-latest needs: @@ -95,6 +108,7 @@ jobs: lint, check-changelog, unit-tests, + local-h5-integration-tests, smoke-test, docs-build, ] diff --git a/tests/integration/local_h5/fixtures.py b/tests/integration/local_h5/fixtures.py new file mode 100644 index 000000000..3edf0f020 --- /dev/null +++ b/tests/integration/local_h5/fixtures.py @@ -0,0 +1,203 @@ +"""Shared tiny-artifact fixtures for local H5 integration tests.""" + +from __future__ import annotations + +import json +import pickle +import shutil +import sqlite3 +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +import numpy as np + +from policyengine_us_data.calibration.clone_and_assign import ( + GeographyAssignment, + save_geography, +) +from policyengine_us_data.calibration.local_h5.requests import ( + AreaBuildRequest, + AreaFilter, +) + +FIXTURE_DATASET_PATH = Path(__file__).resolve().parents[1] / "test_fixture_50hh.h5" +DISTRICT_GEOID = "3701" +COUNTY_FIPS = "37183" +STATE_FIPS = 37 +N_CLONES = 1 +SEED = 42 +VERSION = "0.0.0" + + +@dataclass(frozen=True) +class LocalH5Artifacts: + dataset_path: Path + weights_path: Path + db_path: Path + run_config_path: Path + geography_path: Path + calibration_package_path: Path + geography: GeographyAssignment + n_records: int + n_clones: int + + +@lru_cache(maxsize=1) +def fixture_household_count() -> int: + from policyengine_us import Microsimulation + + sim = Microsimulation(dataset=str(FIXTURE_DATASET_PATH)) + try: + return int(len(sim.calculate("household_id", map_to="household").values)) + finally: + del sim + + +def base_geography(*, n_records: int, n_clones: int = N_CLONES) -> GeographyAssignment: + total_rows = n_records * n_clones + block_geoids = np.array( + [f"{COUNTY_FIPS}{i:06d}{i:04d}"[:15] for i in range(total_rows)], + dtype="U15", + ) + return GeographyAssignment( + block_geoid=block_geoids, + cd_geoid=np.full(total_rows, DISTRICT_GEOID, dtype="U4"), + county_fips=np.full(total_rows, COUNTY_FIPS, dtype="U5"), + state_fips=np.full(total_rows, STATE_FIPS, dtype=np.int32), + n_records=n_records, + n_clones=n_clones, + ) + + +def seed_local_h5_artifacts( + tmp_path: Path, + *, + n_clones: int = N_CLONES, +) -> LocalH5Artifacts: + artifact_dir = tmp_path / "artifacts" + if artifact_dir.exists(): + shutil.rmtree(artifact_dir) + artifact_dir.mkdir(parents=True, exist_ok=True) + + dataset_path = artifact_dir / "source.h5" + weights_path = artifact_dir / "calibration_weights.npy" + db_path = artifact_dir / "policy_data.db" + run_config_path = artifact_dir / "unified_run_config.json" + geography_path = artifact_dir / "geography_assignment.npz" + calibration_package_path = artifact_dir / "calibration_package.pkl" + + shutil.copy2(FIXTURE_DATASET_PATH, dataset_path) + n_records = fixture_household_count() + np.save(weights_path, np.ones(n_records * n_clones, dtype=np.float32)) + + geography = base_geography(n_records=n_records, n_clones=n_clones) + save_geography(geography, geography_path) + + with open(calibration_package_path, "wb") as handle: + pickle.dump( + { + "block_geoid": geography.block_geoid, + "cd_geoid": geography.cd_geoid, + "metadata": { + "git_commit": "deadbeefcafebabe", + "git_branch": "main", + "git_dirty": False, + "package_version": VERSION, + }, + }, + handle, + protocol=pickle.HIGHEST_PROTOCOL, + ) + + conn = sqlite3.connect(db_path) + try: + conn.execute( + """ + CREATE TABLE stratum_constraints ( + stratum_id INTEGER, + constraint_variable TEXT, + value TEXT + ) + """ + ) + conn.execute( + """ + INSERT INTO stratum_constraints (stratum_id, constraint_variable, value) + VALUES (?, ?, ?) + """, + (1, "congressional_district_geoid", DISTRICT_GEOID), + ) + conn.commit() + finally: + conn.close() + + run_config_path.write_text( + json.dumps( + { + "git_commit": "deadbeefcafebabe", + "git_branch": "main", + "git_dirty": False, + "package_version": VERSION, + } + ) + ) + + return LocalH5Artifacts( + dataset_path=dataset_path, + weights_path=weights_path, + db_path=db_path, + run_config_path=run_config_path, + geography_path=geography_path, + calibration_package_path=calibration_package_path, + geography=geography, + n_records=n_records, + n_clones=n_clones, + ) + + +def build_request( + area_type: str, *, geography: GeographyAssignment +) -> AreaBuildRequest: + if area_type == "district": + return AreaBuildRequest( + area_type="district", + area_id="NC-01", + display_name="NC-01", + output_relative_path="districts/NC-01.h5", + filters=( + AreaFilter( + geography_field="cd_geoid", + op="in", + value=(DISTRICT_GEOID,), + ), + ), + validation_geo_level="district", + validation_geographic_ids=(DISTRICT_GEOID,), + ) + if area_type == "state": + return AreaBuildRequest( + area_type="state", + area_id="NC", + display_name="NC", + output_relative_path="states/NC.h5", + filters=( + AreaFilter( + geography_field="cd_geoid", + op="in", + value=(DISTRICT_GEOID,), + ), + ), + validation_geo_level="state", + validation_geographic_ids=(str(STATE_FIPS),), + ) + if area_type == "national": + return AreaBuildRequest( + area_type="national", + area_id="US", + display_name="US", + output_relative_path="national/US.h5", + validation_geo_level="national", + validation_geographic_ids=("US",), + ) + raise ValueError(f"Unsupported area_type for test fixture: {area_type}") diff --git a/tests/integration/local_h5/test_modal_local_area_traceability.py b/tests/integration/local_h5/test_modal_local_area_traceability.py new file mode 100644 index 000000000..13d86ad30 --- /dev/null +++ b/tests/integration/local_h5/test_modal_local_area_traceability.py @@ -0,0 +1,64 @@ +from policyengine_us_data.calibration.local_h5.fingerprinting import ( + FingerprintingService, +) + +from tests.integration.local_h5.fixtures import SEED, VERSION, seed_local_h5_artifacts +from tests.unit.fixtures.test_modal_local_area import load_local_area_module + + +def test_local_area_helpers_match_publish_traceability_contract(tmp_path): + local_area = load_local_area_module(stub_policyengine=False) + artifacts = seed_local_h5_artifacts(tmp_path) + + inputs = local_area._build_publishing_input_bundle( + weights_path=artifacts.weights_path, + dataset_path=artifacts.dataset_path, + db_path=artifacts.db_path, + geography_path=artifacts.geography_path, + calibration_package_path=artifacts.calibration_package_path, + run_config_path=artifacts.run_config_path, + run_id="run-123", + version=VERSION, + n_clones=artifacts.n_clones, + seed=SEED, + ) + + helper_fingerprint = local_area._resolve_scope_fingerprint( + inputs=inputs, + scope="regional", + ) + service = FingerprintingService() + service_fingerprint = service.compute_scope_fingerprint( + service.build_traceability(inputs=inputs, scope="regional") + ) + + assert helper_fingerprint == service_fingerprint + + +def test_local_area_scope_helper_distinguishes_regional_and_national(tmp_path): + local_area = load_local_area_module(stub_policyengine=False) + artifacts = seed_local_h5_artifacts(tmp_path) + + inputs = local_area._build_publishing_input_bundle( + weights_path=artifacts.weights_path, + dataset_path=artifacts.dataset_path, + db_path=artifacts.db_path, + geography_path=artifacts.geography_path, + calibration_package_path=artifacts.calibration_package_path, + run_config_path=artifacts.run_config_path, + run_id="run-123", + version=VERSION, + n_clones=artifacts.n_clones, + seed=SEED, + ) + + regional = local_area._resolve_scope_fingerprint( + inputs=inputs, + scope="regional", + ) + national = local_area._resolve_scope_fingerprint( + inputs=inputs, + scope="national", + ) + + assert regional != national diff --git a/tests/integration/local_h5/test_traceability_contract.py b/tests/integration/local_h5/test_traceability_contract.py new file mode 100644 index 000000000..65fa0b678 --- /dev/null +++ b/tests/integration/local_h5/test_traceability_contract.py @@ -0,0 +1,88 @@ +from policyengine_us_data.calibration.local_h5.fingerprinting import ( + FingerprintingService, + PublishingInputBundle, +) + +from tests.integration.local_h5.fixtures import SEED, VERSION, seed_local_h5_artifacts + + +def _fingerprint_for(*, inputs, scope: str = "regional") -> str: + service = FingerprintingService() + return service.compute_scope_fingerprint( + service.build_traceability(inputs=inputs, scope=scope) + ) + + +def test_saved_geography_bundle_builds_traceability_with_stable_fingerprint(tmp_path): + artifacts = seed_local_h5_artifacts(tmp_path) + inputs = PublishingInputBundle( + weights_path=artifacts.weights_path, + source_dataset_path=artifacts.dataset_path, + target_db_path=artifacts.db_path, + exact_geography_path=artifacts.geography_path, + calibration_package_path=None, + run_config_path=artifacts.run_config_path, + run_id="run-123", + version=VERSION, + n_clones=artifacts.n_clones, + seed=SEED, + ) + + first = _fingerprint_for(inputs=inputs) + second = _fingerprint_for(inputs=inputs) + + assert first == second + + +def test_package_geography_bundle_builds_traceability_with_stable_fingerprint(tmp_path): + artifacts = seed_local_h5_artifacts(tmp_path) + inputs = PublishingInputBundle( + weights_path=artifacts.weights_path, + source_dataset_path=artifacts.dataset_path, + target_db_path=artifacts.db_path, + exact_geography_path=None, + calibration_package_path=artifacts.calibration_package_path, + run_config_path=artifacts.run_config_path, + run_id="run-123", + version=VERSION, + n_clones=artifacts.n_clones, + seed=SEED, + ) + + first = _fingerprint_for(inputs=inputs) + second = _fingerprint_for(inputs=inputs) + + assert first == second + + +def test_saved_and_package_geography_share_the_same_resumability_identity(tmp_path): + artifacts = seed_local_h5_artifacts(tmp_path) + saved_inputs = PublishingInputBundle( + weights_path=artifacts.weights_path, + source_dataset_path=artifacts.dataset_path, + target_db_path=artifacts.db_path, + exact_geography_path=artifacts.geography_path, + calibration_package_path=None, + run_config_path=artifacts.run_config_path, + run_id="run-123", + version=VERSION, + n_clones=artifacts.n_clones, + seed=SEED, + ) + package_inputs = PublishingInputBundle( + weights_path=artifacts.weights_path, + source_dataset_path=artifacts.dataset_path, + target_db_path=artifacts.db_path, + exact_geography_path=None, + calibration_package_path=artifacts.calibration_package_path, + run_config_path=artifacts.run_config_path, + run_id="run-123", + version=VERSION, + n_clones=artifacts.n_clones, + seed=SEED, + ) + + saved_fingerprint = _fingerprint_for(inputs=saved_inputs) + package_fingerprint = _fingerprint_for(inputs=package_inputs) + + assert saved_fingerprint == package_fingerprint diff --git a/tests/integration/local_h5/test_worker_script_tiny_fixture.py b/tests/integration/local_h5/test_worker_script_tiny_fixture.py new file mode 100644 index 000000000..12b6a0426 --- /dev/null +++ b/tests/integration/local_h5/test_worker_script_tiny_fixture.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +import json +import subprocess +import sys +from pathlib import Path + +import pytest + +from tests.integration.local_h5.fixtures import ( + build_request, + seed_local_h5_artifacts, +) + +pytest.importorskip("scipy") +pytest.importorskip("spm_calculator") + + +def _run_worker( + *, + request, + artifacts, + output_dir: Path, + use_saved_geography: bool = False, + use_package_geography: bool = False, +) -> dict: + cmd = [ + sys.executable, + "-m", + "modal_app.worker_script", + "--requests-json", + json.dumps([request.to_dict()]), + "--weights-path", + str(artifacts.weights_path), + "--dataset-path", + str(artifacts.dataset_path), + "--db-path", + str(artifacts.db_path), + "--output-dir", + str(output_dir), + "--n-clones", + str(artifacts.n_clones), + "--no-validate", + ] + if use_saved_geography: + cmd.extend(["--geography-path", str(artifacts.geography_path)]) + if use_package_geography: + cmd.extend( + [ + "--calibration-package-path", + str(artifacts.calibration_package_path), + ] + ) + + result = subprocess.run( + cmd, + capture_output=True, + text=True, + check=True, + ) + return json.loads(result.stdout) + + +def test_worker_builds_district_h5_from_saved_geography(tmp_path): + artifacts = seed_local_h5_artifacts(tmp_path / "district") + request = build_request("district", geography=artifacts.geography) + output_dir = tmp_path / "district-out" + + result = _run_worker( + request=request, + artifacts=artifacts, + output_dir=output_dir, + use_saved_geography=True, + ) + + assert result["failed"] == [] + assert result["errors"] == [] + assert result["completed"] == [f"district:{request.area_id}"] + assert (output_dir / request.output_relative_path).exists() + + +def test_worker_builds_state_h5_from_package_geography(tmp_path): + artifacts = seed_local_h5_artifacts(tmp_path / "state") + request = build_request("state", geography=artifacts.geography) + output_dir = tmp_path / "state-out" + + result = _run_worker( + request=request, + artifacts=artifacts, + output_dir=output_dir, + use_package_geography=True, + ) + + assert result["failed"] == [] + assert result["errors"] == [] + assert result["completed"] == [f"state:{request.area_id}"] + assert (output_dir / request.output_relative_path).exists() + + +def test_worker_builds_national_h5_from_package_geography(tmp_path): + artifacts = seed_local_h5_artifacts(tmp_path / "national") + request = build_request("national", geography=artifacts.geography) + output_dir = tmp_path / "national-out" + + result = _run_worker( + request=request, + artifacts=artifacts, + output_dir=output_dir, + use_package_geography=True, + ) + + assert result["failed"] == [] + assert result["errors"] == [] + assert result["completed"] == ["national:US"] + assert (output_dir / request.output_relative_path).exists() diff --git a/tests/unit/calibration/test_local_h5_fingerprinting.py b/tests/unit/calibration/test_local_h5_fingerprinting.py index 66f288738..08d2b593b 100644 --- a/tests/unit/calibration/test_local_h5_fingerprinting.py +++ b/tests/unit/calibration/test_local_h5_fingerprinting.py @@ -96,3 +96,54 @@ def test_traceability_uses_weight_derived_household_count_for_geography(tmp_path assert traceability.exact_geography.metadata["canonical_sha256"].startswith( "sha256:" ) + + +def test_resumability_material_prefers_canonical_geography_checksum(tmp_path): + inputs = make_publishing_inputs(PublishingInputBundle, tmp_path=tmp_path) + + service = FingerprintingService() + traceability = service.build_traceability(inputs=inputs, scope="regional") + resumability = traceability.resumability_material() + + assert traceability.exact_geography is not None + assert ( + resumability["exact_geography_sha256"] + == traceability.exact_geography.metadata["canonical_sha256"] + ) + + +def test_traceability_handles_missing_optional_artifacts(tmp_path): + inputs = make_publishing_inputs(PublishingInputBundle, tmp_path=tmp_path) + standalone_weights_path = tmp_path / "standalone" / "weights.npy" + standalone_weights_path.parent.mkdir(parents=True, exist_ok=True) + standalone_weights_path.write_bytes(inputs.weights_path.read_bytes()) + inputs = PublishingInputBundle( + weights_path=standalone_weights_path, + source_dataset_path=inputs.source_dataset_path, + target_db_path=None, + exact_geography_path=None, + calibration_package_path=None, + run_config_path=None, + run_id=inputs.run_id, + version=inputs.version, + n_clones=inputs.n_clones, + seed=inputs.seed, + legacy_blocks_path=None, + ) + + service = FingerprintingService() + traceability = service.build_traceability(inputs=inputs, scope="regional") + + assert traceability.target_db is None + assert traceability.exact_geography is None + assert traceability.calibration_package is None + assert traceability.run_config is None + assert traceability.code_version == { + "git_commit": None, + "git_branch": None, + "git_dirty": None, + } + assert traceability.model_build == { + "locked_version": None, + "git_commit": None, + } diff --git a/tests/unit/fixtures/test_modal_local_area.py b/tests/unit/fixtures/test_modal_local_area.py index 935da8d6e..db9d0e621 100644 --- a/tests/unit/fixtures/test_modal_local_area.py +++ b/tests/unit/fixtures/test_modal_local_area.py @@ -31,22 +31,10 @@ def _patched_module_registry(overrides: dict[str, ModuleType]): sys.modules[name] = module -def load_local_area_module(): +def load_local_area_module(*, stub_policyengine: bool = True): """Import `modal_app.local_area` with scoped fake Modal dependencies.""" fake_modal = ModuleType("modal") - fake_policyengine = ModuleType("policyengine_us_data") - fake_calibration = ModuleType("policyengine_us_data.calibration") - fake_local_h5 = ModuleType("policyengine_us_data.calibration.local_h5") - fake_partitioning = ModuleType( - "policyengine_us_data.calibration.local_h5.partitioning" - ) - fake_fingerprinting = ModuleType( - "policyengine_us_data.calibration.local_h5.fingerprinting" - ) - fake_policyengine.__path__ = [] - fake_calibration.__path__ = [] - fake_local_h5.__path__ = [] class _FakeApp: def __init__(self, *args, **kwargs): @@ -73,32 +61,50 @@ def decorator(func): fake_resilience = ModuleType("modal_app.resilience") fake_resilience.reconcile_run_dir_fingerprint = lambda *args, **kwargs: None - fake_partitioning.partition_weighted_work_items = lambda *args, **kwargs: [] - fake_fingerprinting.PublishingInputBundle = object - - class _FakeFingerprintingService: - def build_traceability(self, *args, **kwargs): - return object() - - def compute_scope_fingerprint(self, *args, **kwargs): - return "fake-fingerprint" - - fake_fingerprinting.FingerprintingService = _FakeFingerprintingService - - with _patched_module_registry( - { - "modal": fake_modal, - "modal_app.images": fake_images, - "modal_app.resilience": fake_resilience, - "policyengine_us_data": fake_policyengine, - "policyengine_us_data.calibration": fake_calibration, - "policyengine_us_data.calibration.local_h5": fake_local_h5, - "policyengine_us_data.calibration.local_h5.fingerprinting": ( - fake_fingerprinting - ), - "policyengine_us_data.calibration.local_h5.partitioning": ( - fake_partitioning - ), - } - ): + + overrides = { + "modal": fake_modal, + "modal_app.images": fake_images, + "modal_app.resilience": fake_resilience, + } + + if stub_policyengine: + fake_policyengine = ModuleType("policyengine_us_data") + fake_calibration = ModuleType("policyengine_us_data.calibration") + fake_local_h5 = ModuleType("policyengine_us_data.calibration.local_h5") + fake_partitioning = ModuleType( + "policyengine_us_data.calibration.local_h5.partitioning" + ) + fake_fingerprinting = ModuleType( + "policyengine_us_data.calibration.local_h5.fingerprinting" + ) + fake_policyengine.__path__ = [] + fake_calibration.__path__ = [] + fake_local_h5.__path__ = [] + fake_partitioning.partition_weighted_work_items = lambda *args, **kwargs: [] + fake_fingerprinting.PublishingInputBundle = object + + class _FakeFingerprintingService: + def build_traceability(self, *args, **kwargs): + return object() + + def compute_scope_fingerprint(self, *args, **kwargs): + return "fake-fingerprint" + + fake_fingerprinting.FingerprintingService = _FakeFingerprintingService + overrides.update( + { + "policyengine_us_data": fake_policyengine, + "policyengine_us_data.calibration": fake_calibration, + "policyengine_us_data.calibration.local_h5": fake_local_h5, + "policyengine_us_data.calibration.local_h5.fingerprinting": ( + fake_fingerprinting + ), + "policyengine_us_data.calibration.local_h5.partitioning": ( + fake_partitioning + ), + } + ) + + with _patched_module_registry(overrides): return importlib.import_module("modal_app.local_area") diff --git a/tests/unit/test_modal_local_area.py b/tests/unit/test_modal_local_area.py index 0e3cd9fd6..e8128db71 100644 --- a/tests/unit/test_modal_local_area.py +++ b/tests/unit/test_modal_local_area.py @@ -1,3 +1,5 @@ +from pathlib import Path + from tests.unit.fixtures.test_modal_local_area import load_local_area_module @@ -28,3 +30,164 @@ def test_build_promote_publish_script_finalizes_complete_release(): assert "should_finalize_local_area_release" in script assert "create_tag=should_finalize" in script assert "upload_manifest(" in script + + +def test_build_publishing_input_bundle_preserves_traceability_inputs(): + local_area = load_local_area_module(stub_policyengine=False) + + bundle = local_area._build_publishing_input_bundle( + weights_path=Path("/tmp/calibration_weights.npy"), + dataset_path=Path("/tmp/source.h5"), + db_path=Path("/tmp/policy_data.db"), + geography_path=Path("/tmp/geography_assignment.npz"), + calibration_package_path=Path("/tmp/calibration_package.pkl"), + run_config_path=Path("/tmp/unified_run_config.json"), + run_id="run-123", + version="1.2.3", + n_clones=4, + seed=42, + legacy_blocks_path=Path("/tmp/stacked_blocks.npy"), + ) + + assert bundle.weights_path == Path("/tmp/calibration_weights.npy") + assert bundle.source_dataset_path == Path("/tmp/source.h5") + assert bundle.target_db_path == Path("/tmp/policy_data.db") + assert bundle.exact_geography_path == Path("/tmp/geography_assignment.npz") + assert bundle.calibration_package_path == Path("/tmp/calibration_package.pkl") + assert bundle.run_config_path == Path("/tmp/unified_run_config.json") + assert bundle.run_id == "run-123" + assert bundle.version == "1.2.3" + assert bundle.n_clones == 4 + assert bundle.seed == 42 + assert bundle.legacy_blocks_path == Path("/tmp/stacked_blocks.npy") + + +def test_resolve_scope_fingerprint_computes_when_no_pin(monkeypatch): + local_area = load_local_area_module(stub_policyengine=False) + + seen = {} + + class FakeFingerprintingService: + def build_traceability(self, *, inputs, scope): + seen["inputs"] = inputs + seen["scope"] = scope + return {"scope": scope, "run_id": inputs.run_id} + + def compute_scope_fingerprint(self, traceability): + seen["traceability"] = traceability + return "computed-fingerprint" + + monkeypatch.setattr( + local_area, + "FingerprintingService", + FakeFingerprintingService, + ) + + bundle = local_area._build_publishing_input_bundle( + weights_path=Path("/tmp/calibration_weights.npy"), + dataset_path=Path("/tmp/source.h5"), + db_path=None, + geography_path=None, + calibration_package_path=None, + run_config_path=None, + run_id="run-123", + version="1.2.3", + n_clones=2, + seed=42, + ) + + fingerprint = local_area._resolve_scope_fingerprint( + inputs=bundle, + scope="regional", + ) + + assert fingerprint == "computed-fingerprint" + assert seen["inputs"] == bundle + assert seen["scope"] == "regional" + assert seen["traceability"] == {"scope": "regional", "run_id": "run-123"} + + +def test_resolve_scope_fingerprint_preserves_matching_pin(monkeypatch, capsys): + local_area = load_local_area_module(stub_policyengine=False) + + class FakeFingerprintingService: + def build_traceability(self, *, inputs, scope): + return scope + + def compute_scope_fingerprint(self, traceability): + return "pinned-fingerprint" + + monkeypatch.setattr( + local_area, + "FingerprintingService", + FakeFingerprintingService, + ) + + bundle = local_area._build_publishing_input_bundle( + weights_path=Path("/tmp/calibration_weights.npy"), + dataset_path=Path("/tmp/source.h5"), + db_path=None, + geography_path=None, + calibration_package_path=None, + run_config_path=None, + run_id="run-123", + version="1.2.3", + n_clones=2, + seed=42, + ) + + fingerprint = local_area._resolve_scope_fingerprint( + inputs=bundle, + scope="regional", + expected_fingerprint="pinned-fingerprint", + ) + + captured = capsys.readouterr() + assert fingerprint == "pinned-fingerprint" + assert "Using pinned fingerprint from pipeline" in captured.out + + +def test_resolve_scope_fingerprint_warns_and_preserves_mismatched_pin( + monkeypatch, capsys +): + local_area = load_local_area_module(stub_policyengine=False) + + class FakeFingerprintingService: + def build_traceability(self, *, inputs, scope): + return scope + + def compute_scope_fingerprint(self, traceability): + return "computed-fingerprint" + + monkeypatch.setattr( + local_area, + "FingerprintingService", + FakeFingerprintingService, + ) + + bundle = local_area._build_publishing_input_bundle( + weights_path=Path("/tmp/calibration_weights.npy"), + dataset_path=Path("/tmp/source.h5"), + db_path=None, + geography_path=None, + calibration_package_path=None, + run_config_path=None, + run_id="run-123", + version="1.2.3", + n_clones=2, + seed=42, + ) + + fingerprint = local_area._resolve_scope_fingerprint( + inputs=bundle, + scope="national", + expected_fingerprint="legacy-fingerprint", + ) + + captured = capsys.readouterr() + assert fingerprint == "legacy-fingerprint" + assert "Pinned fingerprint differs from current national scope fingerprint" in ( + captured.out + ) + assert "legacy-fingerprint" in captured.out + assert "computed-fingerprint" in captured.out diff --git a/tests/unit/test_pipeline.py b/tests/unit/test_pipeline.py index d9b31714e..5458c53a7 100644 --- a/tests/unit/test_pipeline.py +++ b/tests/unit/test_pipeline.py @@ -8,7 +8,7 @@ modal = pytest.importorskip("modal") -from modal_app.pipeline import ( +from modal_app.pipeline import ( # noqa: E402 RunMetadata, _step_completed, _record_step, @@ -79,6 +79,23 @@ def test_from_dict_maps_legacy_fingerprint_to_regional_scope(self): assert meta.fingerprint == "legacy-fingerprint" assert meta.regional_fingerprint == "legacy-fingerprint" + def test_from_dict_keeps_explicit_regional_fingerprint_when_both_present(self): + meta = RunMetadata.from_dict( + { + "run_id": "test", + "branch": "main", + "sha": "abc12345deadbeef", + "version": "1.72.3", + "start_time": "2026-03-19T12:00:00Z", + "status": "running", + "fingerprint": "legacy-fingerprint", + "regional_fingerprint": "regional-fingerprint", + } + ) + + assert meta.fingerprint == "legacy-fingerprint" + assert meta.regional_fingerprint == "regional-fingerprint" + def test_roundtrip(self): meta = RunMetadata( run_id="1.72.3_abc12345_20260319_120000", @@ -111,6 +128,23 @@ def test_to_dict_keeps_legacy_fingerprint_alias_in_sync(self): assert payload["fingerprint"] == "regional-fp" assert payload["regional_fingerprint"] == "regional-fp" + def test_to_dict_preserves_distinct_explicit_regional_fingerprint(self): + meta = RunMetadata( + run_id="test", + branch="main", + sha="abc", + version="1.0.0", + start_time="now", + status="running", + fingerprint="legacy-fp", + regional_fingerprint="regional-fp", + ) + + payload = meta.to_dict() + + assert payload["fingerprint"] == "legacy-fp" + assert payload["regional_fingerprint"] == "regional-fp" + def test_step_timings_default_empty(self): meta = RunMetadata( run_id="test",