diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index 797ca19b0..fccc63078 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -86,6 +86,19 @@ jobs: env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + local-h5-integration-tests: + runs-on: ubuntu-latest + needs: [check-fork, lint, unit-tests] + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.14" + - uses: astral-sh/setup-uv@v5 + - run: uv sync --dev + - name: Run local H5 integration tests + run: uv run pytest --noconftest tests/integration/local_h5/ -v + optimized-integration-tests: runs-on: ubuntu-latest needs: @@ -95,6 +108,7 @@ jobs: lint, check-changelog, unit-tests, + local-h5-integration-tests, smoke-test, docs-build, ] diff --git a/modal_app/local_area.py b/modal_app/local_area.py index 0beafee5c..97499f08b 100644 --- a/modal_app/local_area.py +++ b/modal_app/local_area.py @@ -29,6 +29,10 @@ from modal_app.images import cpu_image as image # noqa: E402 from modal_app.resilience import reconcile_run_dir_fingerprint # noqa: E402 +from policyengine_us_data.calibration.local_h5.fingerprinting import ( # noqa: E402 + FingerprintingService, + PublishingInputBundle, +) from policyengine_us_data.calibration.local_h5.partitioning import ( # noqa: E402 partition_weighted_work_items, ) @@ -311,6 +315,65 @@ def get_version() -> str: return pyproject["project"]["version"] +def _build_publishing_input_bundle( + *, + weights_path: Path, + dataset_path: Path, + db_path: Path | None, + geography_path: Path | None, + calibration_package_path: Path | None, + run_config_path: Path | None, + run_id: str, + version: str, + n_clones: int | None, + seed: int, + legacy_blocks_path: Path | None = None, +) -> PublishingInputBundle: + """Build the normalized coordinator input bundle for one publish scope.""" + + return PublishingInputBundle( + weights_path=weights_path, + source_dataset_path=dataset_path, + target_db_path=db_path, + exact_geography_path=geography_path, + calibration_package_path=calibration_package_path, + run_config_path=run_config_path, + run_id=run_id, + version=version, + n_clones=n_clones, + seed=seed, + legacy_blocks_path=legacy_blocks_path, + ) + + +def _resolve_scope_fingerprint( + *, + inputs: PublishingInputBundle, + scope: str, + expected_fingerprint: str = "", +) -> str: + """Compute the scope fingerprint while preserving pinned resume values.""" + + service = FingerprintingService() + traceability = service.build_traceability(inputs=inputs, scope=scope) + computed_fingerprint = service.compute_scope_fingerprint(traceability) + if expected_fingerprint: + if expected_fingerprint != computed_fingerprint: + print( + "WARNING: Pinned fingerprint differs from current " + f"{scope} scope fingerprint. " + "Preserving pinned value for backward-compatible resume.\n" + f" Pinned: {expected_fingerprint}\n" + f" Current: {computed_fingerprint}" + ) + else: + print( + f"Using pinned fingerprint from pipeline: {expected_fingerprint}" + ) + return expected_fingerprint + return computed_fingerprint + + def partition_work( work_items: List[Dict], num_workers: int, @@ -836,45 +899,26 @@ def coordinate_publish( validate = False # Fingerprint-based cache invalidation - if expected_fingerprint: - fingerprint = expected_fingerprint - print(f"Using pinned fingerprint from pipeline: {fingerprint}") - else: - geography_path_expr = ( - f'Path("{geography_path}")' if geography_path.exists() else "None" - ) - package_path_expr = ( - f'Path("{calibration_package_path}")' - if calibration_package_path.exists() - else "None" - ) - fp_result = subprocess.run( - _python_cmd( - "-c", - f""" -from pathlib import Path -from policyengine_us_data.calibration.publish_local_area import ( - compute_input_fingerprint, -) -print( - compute_input_fingerprint( - Path("{weights_path}"), - Path("{dataset_path}"), - {n_clones}, + fingerprint_inputs = _build_publishing_input_bundle( + weights_path=weights_path, + dataset_path=dataset_path, + db_path=db_path, + geography_path=geography_path, + calibration_package_path=( + calibration_package_path if calibration_package_path.exists() else None + ), + run_config_path=config_json_path if config_json_path.exists() else None, + run_id=run_id, + version=version, + n_clones=n_clones, seed=42, - geography_path={geography_path_expr}, - calibration_package_path={package_path_expr}, + legacy_blocks_path=artifacts / "stacked_blocks.npy", + ) + fingerprint = _resolve_scope_fingerprint( + inputs=fingerprint_inputs, + scope="regional", + expected_fingerprint=expected_fingerprint, ) -) -""", - ), - capture_output=True, - text=True, - env=os.environ.copy(), - ) - if fp_result.returncode != 0: - raise RuntimeError(f"Failed to compute fingerprint: {fp_result.stderr}") - fingerprint = fp_result.stdout.strip() reconcile_action = reconcile_run_dir_fingerprint(run_dir, fingerprint) if reconcile_action == "resume": print(f"Inputs unchanged ({fingerprint}), resuming...") @@ -1123,6 +1167,22 @@ def coordinate_national_publish( "geography_assignment.npz": "national_geography_assignment.npz", }, ) + fingerprint_inputs = _build_publishing_input_bundle( + weights_path=weights_path, + dataset_path=dataset_path, + db_path=db_path, + geography_path=geography_path, + calibration_package_path=None, + run_config_path=config_json_path if config_json_path.exists() else None, + run_id=run_id, + version=version, + n_clones=n_clones, + seed=42, + ) + fingerprint = _resolve_scope_fingerprint( + inputs=fingerprint_inputs, + scope="national", + ) run_dir = staging_dir / run_id run_dir.mkdir(parents=True, exist_ok=True) @@ -1224,6 +1284,7 @@ def coordinate_national_publish( f"{version}. Run main_national_promote to publish." ), "run_id": run_id, + "fingerprint": fingerprint, "national_validation": national_validation_output, } diff --git a/modal_app/pipeline.py b/modal_app/pipeline.py index c02d6f10e..5aa89f8fd 100644 --- a/modal_app/pipeline.py +++ b/modal_app/pipeline.py @@ -109,12 +109,25 @@ class RunMetadata: error: Optional[str] = None resume_history: list = field(default_factory=list) fingerprint: Optional[str] = None + regional_fingerprint: Optional[str] = None + + def __post_init__(self) -> None: + if self.regional_fingerprint is None and self.fingerprint is not None: + self.regional_fingerprint = self.fingerprint + if self.fingerprint is None and self.regional_fingerprint is not None: + self.fingerprint = self.regional_fingerprint def to_dict(self) -> dict: - return asdict(self) + data = asdict(self) + if data.get("fingerprint") is None and data.get("regional_fingerprint") is not None: + data["fingerprint"] = data["regional_fingerprint"] + return data @classmethod def from_dict(cls, data: dict) -> "RunMetadata": + data = dict(data) + if data.get("regional_fingerprint") is None and data.get("fingerprint") is not None: + data["regional_fingerprint"] = data["fingerprint"] return cls(**data) @@ -981,7 +994,9 @@ def run_pipeline( n_clones=n_clones, validate=True, run_id=run_id, - expected_fingerprint=meta.fingerprint or "", + expected_fingerprint=( + meta.regional_fingerprint or meta.fingerprint or "" + ), ) print(f" → coordinate_publish fc: {regional_h5_handle.object_id}") @@ -1017,6 +1032,7 @@ def run_pipeline( if isinstance(regional_h5_result, dict) and regional_h5_result.get( "fingerprint" ): + meta.regional_fingerprint = regional_h5_result["fingerprint"] meta.fingerprint = regional_h5_result["fingerprint"] write_run_meta(meta, pipeline_volume) diff --git a/policyengine_us_data/calibration/local_h5/__init__.py b/policyengine_us_data/calibration/local_h5/__init__.py index f69663eb0..96ec7258f 100644 --- a/policyengine_us_data/calibration/local_h5/__init__.py +++ b/policyengine_us_data/calibration/local_h5/__init__.py @@ -3,5 +3,5 @@ Modules in this package should land only when they become active runtime seams rather than speculative placeholders. The current early slices introduce ``partitioning.py``, ``requests.py``, ``area_catalog.py``, -and ``geography_loader.py``. +``fingerprinting.py``, and ``geography_loader.py``. """ diff --git a/policyengine_us_data/calibration/local_h5/fingerprinting.py b/policyengine_us_data/calibration/local_h5/fingerprinting.py new file mode 100644 index 000000000..8f401e582 --- /dev/null +++ b/policyengine_us_data/calibration/local_h5/fingerprinting.py @@ -0,0 +1,254 @@ +"""Coordinator-owned provenance and resumability logic for local H5 publication.""" + +from __future__ import annotations + +import hashlib +import json +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Literal, Mapping + +from .geography_loader import CalibrationGeographyLoader + +FingerprintScope = Literal["regional", "national"] + + +@dataclass(frozen=True) +class PublishingInputBundle: + """File-system and run metadata needed to publish one H5 scope.""" + + weights_path: Path + source_dataset_path: Path + target_db_path: Path | None + exact_geography_path: Path | None + calibration_package_path: Path | None + run_config_path: Path | None + run_id: str + version: str + n_clones: int | None + seed: int + legacy_blocks_path: Path | None = None + + +@dataclass(frozen=True) +class ArtifactIdentity: + """Stable identity for one input artifact used by traceability and resume.""" + + logical_name: str + path: Path | None + sha256: str | None + size_bytes: int | None = None + metadata: Mapping[str, Any] = field(default_factory=dict) + + +@dataclass(frozen=True) +class TraceabilityBundle: + """Full provenance record for one publish scope.""" + + scope: FingerprintScope + weights: ArtifactIdentity + source_dataset: ArtifactIdentity + exact_geography: ArtifactIdentity | None = None + target_db: ArtifactIdentity | None = None + calibration_package: ArtifactIdentity | None = None + run_config: ArtifactIdentity | None = None + code_version: Mapping[str, Any] = field(default_factory=dict) + model_build: Mapping[str, Any] = field(default_factory=dict) + metadata: Mapping[str, Any] = field(default_factory=dict) + + def resumability_material(self) -> Mapping[str, Any]: + """Return the normalized subset that controls staged-output validity.""" + + geography_sha = None + if self.exact_geography is not None: + geography_sha = self.exact_geography.metadata.get("canonical_sha256") + if geography_sha is None: + geography_sha = self.exact_geography.sha256 + + return { + "scope": self.scope, + "weights_sha256": self.weights.sha256, + "source_dataset_sha256": self.source_dataset.sha256, + "exact_geography_sha256": geography_sha, + "target_db_sha256": ( + self.target_db.sha256 if self.target_db is not None else None + ), + "n_clones": self.metadata.get("n_clones"), + "seed": self.metadata.get("seed"), + "policyengine_us_locked_version": self.model_build.get("locked_version"), + "policyengine_us_git_commit": self.model_build.get("git_commit"), + } + + +class FingerprintingService: + """Build traceability bundles and derive scope fingerprints from them.""" + + def __init__( + self, + *, + geography_loader: CalibrationGeographyLoader | None = None, + ) -> None: + self._geography_loader = geography_loader or CalibrationGeographyLoader() + + def build_traceability( + self, + *, + inputs: PublishingInputBundle, + scope: FingerprintScope, + ) -> TraceabilityBundle: + """Build a traceability bundle from current publish inputs.""" + + run_config_payload = self._load_json(inputs.run_config_path) + return TraceabilityBundle( + scope=scope, + weights=self._build_artifact_identity("weights", inputs.weights_path), + source_dataset=self._build_artifact_identity( + "source_dataset", + inputs.source_dataset_path, + ), + exact_geography=self._build_geography_identity(inputs), + target_db=self._build_optional_artifact_identity( + "target_db", + inputs.target_db_path, + ), + calibration_package=self._build_optional_artifact_identity( + "calibration_package", + inputs.calibration_package_path, + ), + run_config=self._build_optional_artifact_identity( + "run_config", + inputs.run_config_path, + ), + code_version=self._extract_code_version(run_config_payload), + model_build=self._extract_model_build(run_config_payload), + metadata={ + "run_id": inputs.run_id, + "version": inputs.version, + "n_clones": inputs.n_clones, + "seed": inputs.seed, + }, + ) + + def compute_scope_fingerprint(self, traceability: TraceabilityBundle) -> str: + """Hash normalized resumability material into a short scope fingerprint.""" + + payload = json.dumps( + traceability.resumability_material(), + sort_keys=True, + separators=(",", ":"), + ).encode() + return hashlib.sha256(payload).hexdigest()[:16] + + def _build_artifact_identity( + self, + logical_name: str, + path: Path, + *, + metadata: Mapping[str, Any] | None = None, + ) -> ArtifactIdentity: + actual_path = Path(path) + if not actual_path.exists(): + raise FileNotFoundError(f"Expected {logical_name} artifact at {actual_path}") + return ArtifactIdentity( + logical_name=logical_name, + path=actual_path, + sha256=self._sha256_file(actual_path), + size_bytes=actual_path.stat().st_size, + metadata=dict(metadata or {}), + ) + + def _build_optional_artifact_identity( + self, + logical_name: str, + path: Path | None, + ) -> ArtifactIdentity | None: + if path is None: + return None + actual_path = Path(path) + if not actual_path.exists(): + return None + return self._build_artifact_identity(logical_name, actual_path) + + def _build_geography_identity( + self, + inputs: PublishingInputBundle, + ) -> ArtifactIdentity | None: + resolved = self._geography_loader.resolve_source( + weights_path=inputs.weights_path, + geography_path=inputs.exact_geography_path, + blocks_path=inputs.legacy_blocks_path, + calibration_package_path=inputs.calibration_package_path, + ) + if resolved is None: + return None + + metadata = { + "source_kind": resolved.kind, + "canonical_sha256": self._geography_loader.compute_canonical_checksum( + weights_path=inputs.weights_path, + n_records=self._infer_n_records( + weights_path=inputs.weights_path, + source_dataset_path=inputs.source_dataset_path, + n_clones=inputs.n_clones, + ), + n_clones=inputs.n_clones, + geography_path=inputs.exact_geography_path, + blocks_path=inputs.legacy_blocks_path, + calibration_package_path=inputs.calibration_package_path, + ), + } + return self._build_artifact_identity( + "exact_geography", + resolved.path, + metadata=metadata, + ) + + def _extract_code_version(self, run_config_payload: Mapping[str, Any]) -> dict[str, Any]: + return { + "git_commit": run_config_payload.get("git_commit"), + "git_branch": run_config_payload.get("git_branch"), + "git_dirty": run_config_payload.get("git_dirty"), + } + + def _extract_model_build(self, run_config_payload: Mapping[str, Any]) -> dict[str, Any]: + return { + "locked_version": run_config_payload.get("package_version"), + "git_commit": run_config_payload.get("git_commit"), + } + + def _load_json(self, path: Path | None) -> Mapping[str, Any]: + if path is None: + return {} + actual_path = Path(path) + if not actual_path.exists(): + return {} + with open(actual_path) as handle: + return json.load(handle) + + def _sha256_file(self, path: Path) -> str: + digest = hashlib.sha256() + with open(path, "rb") as handle: + for chunk in iter(lambda: handle.read(1 << 20), b""): + digest.update(chunk) + return f"sha256:{digest.hexdigest()}" + + def _infer_n_records( + self, + *, + weights_path: Path, + source_dataset_path: Path, + n_clones: int | None, + ) -> int: + if n_clones is not None: + import numpy as np + + weights = np.load(weights_path, mmap_mode="r") + if len(weights) % n_clones == 0: + return int(len(weights) // n_clones) + + from policyengine_us import Microsimulation + + simulation = Microsimulation(dataset=str(source_dataset_path)) + return int( + len(simulation.calculate("household_id", map_to="household").values) + ) diff --git a/policyengine_us_data/calibration/publish_local_area.py b/policyengine_us_data/calibration/publish_local_area.py index b1946a8f3..785fbafc8 100644 --- a/policyengine_us_data/calibration/publish_local_area.py +++ b/policyengine_us_data/calibration/publish_local_area.py @@ -11,12 +11,15 @@ import json import shutil - import numpy as np from pathlib import Path from typing import List, Optional from policyengine_us import Microsimulation +from policyengine_us_data.calibration.local_h5.fingerprinting import ( + FingerprintingService, + PublishingInputBundle, +) from policyengine_us_data.calibration.local_h5.geography_loader import ( CalibrationGeographyLoader, ) @@ -48,8 +51,6 @@ META_FILE = WORK_DIR / "checkpoint_meta.json" - - def compute_input_fingerprint( weights_path: Path, dataset_path: Path, @@ -57,50 +58,33 @@ def compute_input_fingerprint( seed: int = 42, geography_path: Optional[Path] = None, blocks_path: Optional[Path] = None, + target_db_path: Optional[Path] = None, + run_config_path: Optional[Path] = None, calibration_package_path: Optional[Path] = None, + scope: str = "regional", ) -> str: - import hashlib - - def _update_hash_from_file(h: "hashlib._Hash", path: Path) -> None: - with open(path, "rb") as f: - while chunk := f.read(8192): - h.update(chunk) - - def _infer_n_records() -> int: - if n_clones is not None: - weights = np.load(weights_path, mmap_mode="r") - if len(weights) % n_clones == 0: - return len(weights) // n_clones - sim = Microsimulation(dataset=str(dataset_path)) - return len(sim.calculate("household_id", map_to="household").values) - - loader = CalibrationGeographyLoader() - h = hashlib.sha256() - for p in [weights_path, dataset_path]: - _update_hash_from_file(h, p) - - resolved = loader.resolve_source( - weights_path=weights_path, - geography_path=geography_path, - blocks_path=blocks_path, - calibration_package_path=calibration_package_path, + service = FingerprintingService() + inputs = PublishingInputBundle( + weights_path=Path(weights_path), + source_dataset_path=Path(dataset_path), + target_db_path=Path(target_db_path) if target_db_path is not None else None, + exact_geography_path=( + Path(geography_path) if geography_path is not None else None + ), + calibration_package_path=( + Path(calibration_package_path) + if calibration_package_path is not None + else None + ), + run_config_path=Path(run_config_path) if run_config_path is not None else None, + run_id="", + version="", + n_clones=n_clones, + seed=seed, + legacy_blocks_path=Path(blocks_path) if blocks_path is not None else None, ) - if resolved is not None: - n_records = _infer_n_records() - h.update(f"geography_source:{resolved.kind}".encode()) - h.update( - loader.compute_canonical_checksum( - weights_path=weights_path, - n_records=n_records, - n_clones=n_clones, - geography_path=geography_path, - blocks_path=blocks_path, - calibration_package_path=calibration_package_path, - ).encode() - ) - else: - h.update(f"legacy_regeneration:{n_clones}:{seed}".encode()) - return h.hexdigest()[:16] + traceability = service.build_traceability(inputs=inputs, scope=scope) + return service.compute_scope_fingerprint(traceability) def load_calibration_geography( diff --git a/tests/integration/local_h5/fixtures.py b/tests/integration/local_h5/fixtures.py new file mode 100644 index 000000000..3edf0f020 --- /dev/null +++ b/tests/integration/local_h5/fixtures.py @@ -0,0 +1,203 @@ +"""Shared tiny-artifact fixtures for local H5 integration tests.""" + +from __future__ import annotations + +import json +import pickle +import shutil +import sqlite3 +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +import numpy as np + +from policyengine_us_data.calibration.clone_and_assign import ( + GeographyAssignment, + save_geography, +) +from policyengine_us_data.calibration.local_h5.requests import ( + AreaBuildRequest, + AreaFilter, +) + +FIXTURE_DATASET_PATH = Path(__file__).resolve().parents[1] / "test_fixture_50hh.h5" +DISTRICT_GEOID = "3701" +COUNTY_FIPS = "37183" +STATE_FIPS = 37 +N_CLONES = 1 +SEED = 42 +VERSION = "0.0.0" + + +@dataclass(frozen=True) +class LocalH5Artifacts: + dataset_path: Path + weights_path: Path + db_path: Path + run_config_path: Path + geography_path: Path + calibration_package_path: Path + geography: GeographyAssignment + n_records: int + n_clones: int + + +@lru_cache(maxsize=1) +def fixture_household_count() -> int: + from policyengine_us import Microsimulation + + sim = Microsimulation(dataset=str(FIXTURE_DATASET_PATH)) + try: + return int(len(sim.calculate("household_id", map_to="household").values)) + finally: + del sim + + +def base_geography(*, n_records: int, n_clones: int = N_CLONES) -> GeographyAssignment: + total_rows = n_records * n_clones + block_geoids = np.array( + [f"{COUNTY_FIPS}{i:06d}{i:04d}"[:15] for i in range(total_rows)], + dtype="U15", + ) + return GeographyAssignment( + block_geoid=block_geoids, + cd_geoid=np.full(total_rows, DISTRICT_GEOID, dtype="U4"), + county_fips=np.full(total_rows, COUNTY_FIPS, dtype="U5"), + state_fips=np.full(total_rows, STATE_FIPS, dtype=np.int32), + n_records=n_records, + n_clones=n_clones, + ) + + +def seed_local_h5_artifacts( + tmp_path: Path, + *, + n_clones: int = N_CLONES, +) -> LocalH5Artifacts: + artifact_dir = tmp_path / "artifacts" + if artifact_dir.exists(): + shutil.rmtree(artifact_dir) + artifact_dir.mkdir(parents=True, exist_ok=True) + + dataset_path = artifact_dir / "source.h5" + weights_path = artifact_dir / "calibration_weights.npy" + db_path = artifact_dir / "policy_data.db" + run_config_path = artifact_dir / "unified_run_config.json" + geography_path = artifact_dir / "geography_assignment.npz" + calibration_package_path = artifact_dir / "calibration_package.pkl" + + shutil.copy2(FIXTURE_DATASET_PATH, dataset_path) + n_records = fixture_household_count() + np.save(weights_path, np.ones(n_records * n_clones, dtype=np.float32)) + + geography = base_geography(n_records=n_records, n_clones=n_clones) + save_geography(geography, geography_path) + + with open(calibration_package_path, "wb") as handle: + pickle.dump( + { + "block_geoid": geography.block_geoid, + "cd_geoid": geography.cd_geoid, + "metadata": { + "git_commit": "deadbeefcafebabe", + "git_branch": "main", + "git_dirty": False, + "package_version": VERSION, + }, + }, + handle, + protocol=pickle.HIGHEST_PROTOCOL, + ) + + conn = sqlite3.connect(db_path) + try: + conn.execute( + """ + CREATE TABLE stratum_constraints ( + stratum_id INTEGER, + constraint_variable TEXT, + value TEXT + ) + """ + ) + conn.execute( + """ + INSERT INTO stratum_constraints (stratum_id, constraint_variable, value) + VALUES (?, ?, ?) + """, + (1, "congressional_district_geoid", DISTRICT_GEOID), + ) + conn.commit() + finally: + conn.close() + + run_config_path.write_text( + json.dumps( + { + "git_commit": "deadbeefcafebabe", + "git_branch": "main", + "git_dirty": False, + "package_version": VERSION, + } + ) + ) + + return LocalH5Artifacts( + dataset_path=dataset_path, + weights_path=weights_path, + db_path=db_path, + run_config_path=run_config_path, + geography_path=geography_path, + calibration_package_path=calibration_package_path, + geography=geography, + n_records=n_records, + n_clones=n_clones, + ) + + +def build_request( + area_type: str, *, geography: GeographyAssignment +) -> AreaBuildRequest: + if area_type == "district": + return AreaBuildRequest( + area_type="district", + area_id="NC-01", + display_name="NC-01", + output_relative_path="districts/NC-01.h5", + filters=( + AreaFilter( + geography_field="cd_geoid", + op="in", + value=(DISTRICT_GEOID,), + ), + ), + validation_geo_level="district", + validation_geographic_ids=(DISTRICT_GEOID,), + ) + if area_type == "state": + return AreaBuildRequest( + area_type="state", + area_id="NC", + display_name="NC", + output_relative_path="states/NC.h5", + filters=( + AreaFilter( + geography_field="cd_geoid", + op="in", + value=(DISTRICT_GEOID,), + ), + ), + validation_geo_level="state", + validation_geographic_ids=(str(STATE_FIPS),), + ) + if area_type == "national": + return AreaBuildRequest( + area_type="national", + area_id="US", + display_name="US", + output_relative_path="national/US.h5", + validation_geo_level="national", + validation_geographic_ids=("US",), + ) + raise ValueError(f"Unsupported area_type for test fixture: {area_type}") diff --git a/tests/integration/local_h5/test_modal_local_area_traceability.py b/tests/integration/local_h5/test_modal_local_area_traceability.py new file mode 100644 index 000000000..13d86ad30 --- /dev/null +++ b/tests/integration/local_h5/test_modal_local_area_traceability.py @@ -0,0 +1,64 @@ +from policyengine_us_data.calibration.local_h5.fingerprinting import ( + FingerprintingService, +) + +from tests.integration.local_h5.fixtures import SEED, VERSION, seed_local_h5_artifacts +from tests.unit.fixtures.test_modal_local_area import load_local_area_module + + +def test_local_area_helpers_match_publish_traceability_contract(tmp_path): + local_area = load_local_area_module(stub_policyengine=False) + artifacts = seed_local_h5_artifacts(tmp_path) + + inputs = local_area._build_publishing_input_bundle( + weights_path=artifacts.weights_path, + dataset_path=artifacts.dataset_path, + db_path=artifacts.db_path, + geography_path=artifacts.geography_path, + calibration_package_path=artifacts.calibration_package_path, + run_config_path=artifacts.run_config_path, + run_id="run-123", + version=VERSION, + n_clones=artifacts.n_clones, + seed=SEED, + ) + + helper_fingerprint = local_area._resolve_scope_fingerprint( + inputs=inputs, + scope="regional", + ) + service = FingerprintingService() + service_fingerprint = service.compute_scope_fingerprint( + service.build_traceability(inputs=inputs, scope="regional") + ) + + assert helper_fingerprint == service_fingerprint + + +def test_local_area_scope_helper_distinguishes_regional_and_national(tmp_path): + local_area = load_local_area_module(stub_policyengine=False) + artifacts = seed_local_h5_artifacts(tmp_path) + + inputs = local_area._build_publishing_input_bundle( + weights_path=artifacts.weights_path, + dataset_path=artifacts.dataset_path, + db_path=artifacts.db_path, + geography_path=artifacts.geography_path, + calibration_package_path=artifacts.calibration_package_path, + run_config_path=artifacts.run_config_path, + run_id="run-123", + version=VERSION, + n_clones=artifacts.n_clones, + seed=SEED, + ) + + regional = local_area._resolve_scope_fingerprint( + inputs=inputs, + scope="regional", + ) + national = local_area._resolve_scope_fingerprint( + inputs=inputs, + scope="national", + ) + + assert regional != national diff --git a/tests/integration/local_h5/test_traceability_contract.py b/tests/integration/local_h5/test_traceability_contract.py new file mode 100644 index 000000000..65fa0b678 --- /dev/null +++ b/tests/integration/local_h5/test_traceability_contract.py @@ -0,0 +1,88 @@ +from policyengine_us_data.calibration.local_h5.fingerprinting import ( + FingerprintingService, + PublishingInputBundle, +) + +from tests.integration.local_h5.fixtures import SEED, VERSION, seed_local_h5_artifacts + + +def _fingerprint_for(*, inputs, scope: str = "regional") -> str: + service = FingerprintingService() + return service.compute_scope_fingerprint( + service.build_traceability(inputs=inputs, scope=scope) + ) + + +def test_saved_geography_bundle_builds_traceability_with_stable_fingerprint(tmp_path): + artifacts = seed_local_h5_artifacts(tmp_path) + inputs = PublishingInputBundle( + weights_path=artifacts.weights_path, + source_dataset_path=artifacts.dataset_path, + target_db_path=artifacts.db_path, + exact_geography_path=artifacts.geography_path, + calibration_package_path=None, + run_config_path=artifacts.run_config_path, + run_id="run-123", + version=VERSION, + n_clones=artifacts.n_clones, + seed=SEED, + ) + + first = _fingerprint_for(inputs=inputs) + second = _fingerprint_for(inputs=inputs) + + assert first == second + + +def test_package_geography_bundle_builds_traceability_with_stable_fingerprint(tmp_path): + artifacts = seed_local_h5_artifacts(tmp_path) + inputs = PublishingInputBundle( + weights_path=artifacts.weights_path, + source_dataset_path=artifacts.dataset_path, + target_db_path=artifacts.db_path, + exact_geography_path=None, + calibration_package_path=artifacts.calibration_package_path, + run_config_path=artifacts.run_config_path, + run_id="run-123", + version=VERSION, + n_clones=artifacts.n_clones, + seed=SEED, + ) + + first = _fingerprint_for(inputs=inputs) + second = _fingerprint_for(inputs=inputs) + + assert first == second + + +def test_saved_and_package_geography_share_the_same_resumability_identity(tmp_path): + artifacts = seed_local_h5_artifacts(tmp_path) + saved_inputs = PublishingInputBundle( + weights_path=artifacts.weights_path, + source_dataset_path=artifacts.dataset_path, + target_db_path=artifacts.db_path, + exact_geography_path=artifacts.geography_path, + calibration_package_path=None, + run_config_path=artifacts.run_config_path, + run_id="run-123", + version=VERSION, + n_clones=artifacts.n_clones, + seed=SEED, + ) + package_inputs = PublishingInputBundle( + weights_path=artifacts.weights_path, + source_dataset_path=artifacts.dataset_path, + target_db_path=artifacts.db_path, + exact_geography_path=None, + calibration_package_path=artifacts.calibration_package_path, + run_config_path=artifacts.run_config_path, + run_id="run-123", + version=VERSION, + n_clones=artifacts.n_clones, + seed=SEED, + ) + + saved_fingerprint = _fingerprint_for(inputs=saved_inputs) + package_fingerprint = _fingerprint_for(inputs=package_inputs) + + assert saved_fingerprint == package_fingerprint diff --git a/tests/integration/local_h5/test_worker_script_tiny_fixture.py b/tests/integration/local_h5/test_worker_script_tiny_fixture.py new file mode 100644 index 000000000..12b6a0426 --- /dev/null +++ b/tests/integration/local_h5/test_worker_script_tiny_fixture.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +import json +import subprocess +import sys +from pathlib import Path + +import pytest + +from tests.integration.local_h5.fixtures import ( + build_request, + seed_local_h5_artifacts, +) + +pytest.importorskip("scipy") +pytest.importorskip("spm_calculator") + + +def _run_worker( + *, + request, + artifacts, + output_dir: Path, + use_saved_geography: bool = False, + use_package_geography: bool = False, +) -> dict: + cmd = [ + sys.executable, + "-m", + "modal_app.worker_script", + "--requests-json", + json.dumps([request.to_dict()]), + "--weights-path", + str(artifacts.weights_path), + "--dataset-path", + str(artifacts.dataset_path), + "--db-path", + str(artifacts.db_path), + "--output-dir", + str(output_dir), + "--n-clones", + str(artifacts.n_clones), + "--no-validate", + ] + if use_saved_geography: + cmd.extend(["--geography-path", str(artifacts.geography_path)]) + if use_package_geography: + cmd.extend( + [ + "--calibration-package-path", + str(artifacts.calibration_package_path), + ] + ) + + result = subprocess.run( + cmd, + capture_output=True, + text=True, + check=True, + ) + return json.loads(result.stdout) + + +def test_worker_builds_district_h5_from_saved_geography(tmp_path): + artifacts = seed_local_h5_artifacts(tmp_path / "district") + request = build_request("district", geography=artifacts.geography) + output_dir = tmp_path / "district-out" + + result = _run_worker( + request=request, + artifacts=artifacts, + output_dir=output_dir, + use_saved_geography=True, + ) + + assert result["failed"] == [] + assert result["errors"] == [] + assert result["completed"] == [f"district:{request.area_id}"] + assert (output_dir / request.output_relative_path).exists() + + +def test_worker_builds_state_h5_from_package_geography(tmp_path): + artifacts = seed_local_h5_artifacts(tmp_path / "state") + request = build_request("state", geography=artifacts.geography) + output_dir = tmp_path / "state-out" + + result = _run_worker( + request=request, + artifacts=artifacts, + output_dir=output_dir, + use_package_geography=True, + ) + + assert result["failed"] == [] + assert result["errors"] == [] + assert result["completed"] == [f"state:{request.area_id}"] + assert (output_dir / request.output_relative_path).exists() + + +def test_worker_builds_national_h5_from_package_geography(tmp_path): + artifacts = seed_local_h5_artifacts(tmp_path / "national") + request = build_request("national", geography=artifacts.geography) + output_dir = tmp_path / "national-out" + + result = _run_worker( + request=request, + artifacts=artifacts, + output_dir=output_dir, + use_package_geography=True, + ) + + assert result["failed"] == [] + assert result["errors"] == [] + assert result["completed"] == ["national:US"] + assert (output_dir / request.output_relative_path).exists() diff --git a/tests/unit/calibration/fixtures/test_local_h5_fingerprinting.py b/tests/unit/calibration/fixtures/test_local_h5_fingerprinting.py new file mode 100644 index 000000000..2ecffd000 --- /dev/null +++ b/tests/unit/calibration/fixtures/test_local_h5_fingerprinting.py @@ -0,0 +1,117 @@ +"""Fixture helpers for ``test_local_h5_fingerprinting.py``.""" + +from __future__ import annotations + +import importlib +import json +from pathlib import Path + +import h5py +import numpy as np + +from tests.unit.calibration.fixtures.test_local_h5_geography_loader import ( + write_saved_geography, +) + +__test__ = False + +_FINGERPRINTING_EXPORTS = None + + +def load_fingerprinting_exports(): + """Load the fingerprinting module without replacing shared package modules.""" + + global _FINGERPRINTING_EXPORTS + if _FINGERPRINTING_EXPORTS is not None: + return _FINGERPRINTING_EXPORTS + + module = importlib.import_module( + "policyengine_us_data.calibration.local_h5.fingerprinting" + ) + _FINGERPRINTING_EXPORTS = { + "module": module, + "ArtifactIdentity": module.ArtifactIdentity, + "FingerprintingService": module.FingerprintingService, + "PublishingInputBundle": module.PublishingInputBundle, + "TraceabilityBundle": module.TraceabilityBundle, + } + return _FINGERPRINTING_EXPORTS + + +def write_source_dataset( + path: Path, + *, + n_records: int, + person_records: int | None = None, +) -> None: + """Write a minimal HDF5 dataset with a ``person`` entity.""" + + person_count = person_records if person_records is not None else n_records + with h5py.File(path, "w") as handle: + person = handle.create_group("person") + person.create_dataset("person_id", data=np.arange(person_count, dtype=np.int32)) + + +def write_run_config(path: Path, *, package_version: str = "1.0.0") -> None: + """Write a minimal run-config payload with provenance fields.""" + + payload = { + "git_commit": "deadbeefcafebabe", + "git_branch": "main", + "git_dirty": False, + "package_version": package_version, + } + path.write_text(json.dumps(payload)) + + +def write_artifact_file(path: Path, content: bytes) -> None: + """Write one small binary artifact for traceability tests.""" + + path.write_bytes(content) + + +def make_publishing_inputs( + bundle_cls, + *, + tmp_path: Path, + n_records: int = 2, + person_records: int | None = None, + n_clones: int = 2, + seed: int = 42, + package_version: str = "1.0.0", +): + """Create a fully-populated publishing input bundle for tests.""" + + tmp_path.mkdir(parents=True, exist_ok=True) + weights_path = tmp_path / "calibration_weights.npy" + dataset_path = tmp_path / "source.h5" + db_path = tmp_path / "policy_data.db" + geography_path = tmp_path / "geography_assignment.npz" + run_config_path = tmp_path / "unified_run_config.json" + + np.save(weights_path, np.arange(n_records * n_clones, dtype=float) + 1.0) + write_source_dataset( + dataset_path, + n_records=n_records, + person_records=person_records, + ) + write_artifact_file(db_path, b"fake-db") + write_saved_geography( + geography_path, + n_records=n_records, + n_clones=n_clones, + ) + write_run_config(run_config_path, package_version=package_version) + + return bundle_cls( + weights_path=weights_path, + source_dataset_path=dataset_path, + target_db_path=db_path, + exact_geography_path=geography_path, + calibration_package_path=None, + run_config_path=run_config_path, + run_id="run-123", + version="1.2.3", + n_clones=n_clones, + seed=seed, + ) diff --git a/tests/unit/calibration/test_local_h5_fingerprinting.py b/tests/unit/calibration/test_local_h5_fingerprinting.py new file mode 100644 index 000000000..08d2b593b --- /dev/null +++ b/tests/unit/calibration/test_local_h5_fingerprinting.py @@ -0,0 +1,149 @@ +from tests.unit.calibration.fixtures.test_local_h5_fingerprinting import ( + load_fingerprinting_exports, + make_publishing_inputs, +) + + +exports = load_fingerprinting_exports() +FingerprintingService = exports["FingerprintingService"] +PublishingInputBundle = exports["PublishingInputBundle"] + + +def test_build_traceability_captures_artifact_identity_and_metadata(tmp_path): + inputs = make_publishing_inputs(PublishingInputBundle, tmp_path=tmp_path) + + service = FingerprintingService() + traceability = service.build_traceability(inputs=inputs, scope="regional") + + assert traceability.scope == "regional" + assert traceability.weights.path == inputs.weights_path + assert traceability.weights.sha256.startswith("sha256:") + assert traceability.source_dataset.sha256.startswith("sha256:") + assert traceability.exact_geography is not None + assert traceability.exact_geography.metadata["source_kind"] == "saved_geography" + assert traceability.exact_geography.metadata["canonical_sha256"].startswith( + "sha256:" + ) + assert traceability.target_db is not None + assert traceability.model_build["locked_version"] == "1.0.0" + assert traceability.metadata["n_clones"] == 2 + assert traceability.metadata["seed"] == 42 + + +def test_scope_fingerprint_differs_between_regional_and_national(tmp_path): + inputs = make_publishing_inputs(PublishingInputBundle, tmp_path=tmp_path) + + service = FingerprintingService() + regional = service.compute_scope_fingerprint( + service.build_traceability(inputs=inputs, scope="regional") + ) + national = service.compute_scope_fingerprint( + service.build_traceability(inputs=inputs, scope="national") + ) + + assert regional != national + + +def test_scope_fingerprint_is_stable_for_identical_inputs(tmp_path): + inputs = make_publishing_inputs(PublishingInputBundle, tmp_path=tmp_path) + + service = FingerprintingService() + first = service.compute_scope_fingerprint( + service.build_traceability(inputs=inputs, scope="regional") + ) + second = service.compute_scope_fingerprint( + service.build_traceability(inputs=inputs, scope="regional") + ) + + assert first == second + + +def test_scope_fingerprint_changes_when_relevant_provenance_changes(tmp_path): + first_inputs = make_publishing_inputs( + PublishingInputBundle, + tmp_path=tmp_path / "first", + ) + second_inputs = make_publishing_inputs( + PublishingInputBundle, + tmp_path=tmp_path / "second", + ) + second_inputs.target_db_path.write_bytes(b"changed-db") + + service = FingerprintingService() + first = service.compute_scope_fingerprint( + service.build_traceability(inputs=first_inputs, scope="regional") + ) + second = service.compute_scope_fingerprint( + service.build_traceability(inputs=second_inputs, scope="regional") + ) + + assert first != second + + +def test_traceability_uses_weight_derived_household_count_for_geography(tmp_path): + inputs = make_publishing_inputs( + PublishingInputBundle, + tmp_path=tmp_path, + n_records=2, + person_records=5, + n_clones=2, + ) + + service = FingerprintingService() + traceability = service.build_traceability(inputs=inputs, scope="regional") + + assert traceability.exact_geography is not None + assert traceability.exact_geography.metadata["canonical_sha256"].startswith( + "sha256:" + ) + + +def test_resumability_material_prefers_canonical_geography_checksum(tmp_path): + inputs = make_publishing_inputs(PublishingInputBundle, tmp_path=tmp_path) + + service = FingerprintingService() + traceability = service.build_traceability(inputs=inputs, scope="regional") + resumability = traceability.resumability_material() + + assert traceability.exact_geography is not None + assert ( + resumability["exact_geography_sha256"] + == traceability.exact_geography.metadata["canonical_sha256"] + ) + + +def test_traceability_handles_missing_optional_artifacts(tmp_path): + inputs = make_publishing_inputs(PublishingInputBundle, tmp_path=tmp_path) + standalone_weights_path = tmp_path / "standalone" / "weights.npy" + standalone_weights_path.parent.mkdir(parents=True, exist_ok=True) + standalone_weights_path.write_bytes(inputs.weights_path.read_bytes()) + inputs = PublishingInputBundle( + weights_path=standalone_weights_path, + source_dataset_path=inputs.source_dataset_path, + target_db_path=None, + exact_geography_path=None, + calibration_package_path=None, + run_config_path=None, + run_id=inputs.run_id, + version=inputs.version, + n_clones=inputs.n_clones, + seed=inputs.seed, + legacy_blocks_path=None, + ) + + service = FingerprintingService() + traceability = service.build_traceability(inputs=inputs, scope="regional") + + assert traceability.target_db is None + assert traceability.exact_geography is None + assert traceability.calibration_package is None + assert traceability.run_config is None + assert traceability.code_version == { + "git_commit": None, + "git_branch": None, + "git_dirty": None, + } + assert traceability.model_build == { + "locked_version": None, + "git_commit": None, + } diff --git a/tests/unit/fixtures/test_modal_local_area.py b/tests/unit/fixtures/test_modal_local_area.py index 377e879ae..db9d0e621 100644 --- a/tests/unit/fixtures/test_modal_local_area.py +++ b/tests/unit/fixtures/test_modal_local_area.py @@ -31,19 +31,10 @@ def _patched_module_registry(overrides: dict[str, ModuleType]): sys.modules[name] = module -def load_local_area_module(): +def load_local_area_module(*, stub_policyengine: bool = True): """Import `modal_app.local_area` with scoped fake Modal dependencies.""" fake_modal = ModuleType("modal") - fake_policyengine = ModuleType("policyengine_us_data") - fake_calibration = ModuleType("policyengine_us_data.calibration") - fake_local_h5 = ModuleType("policyengine_us_data.calibration.local_h5") - fake_partitioning = ModuleType( - "policyengine_us_data.calibration.local_h5.partitioning" - ) - fake_policyengine.__path__ = [] - fake_calibration.__path__ = [] - fake_local_h5.__path__ = [] class _FakeApp: def __init__(self, *args, **kwargs): @@ -70,19 +61,50 @@ def decorator(func): fake_resilience = ModuleType("modal_app.resilience") fake_resilience.reconcile_run_dir_fingerprint = lambda *args, **kwargs: None - fake_partitioning.partition_weighted_work_items = lambda *args, **kwargs: [] - - with _patched_module_registry( - { - "modal": fake_modal, - "modal_app.images": fake_images, - "modal_app.resilience": fake_resilience, - "policyengine_us_data": fake_policyengine, - "policyengine_us_data.calibration": fake_calibration, - "policyengine_us_data.calibration.local_h5": fake_local_h5, - "policyengine_us_data.calibration.local_h5.partitioning": ( - fake_partitioning - ), - } - ): + + overrides = { + "modal": fake_modal, + "modal_app.images": fake_images, + "modal_app.resilience": fake_resilience, + } + + if stub_policyengine: + fake_policyengine = ModuleType("policyengine_us_data") + fake_calibration = ModuleType("policyengine_us_data.calibration") + fake_local_h5 = ModuleType("policyengine_us_data.calibration.local_h5") + fake_partitioning = ModuleType( + "policyengine_us_data.calibration.local_h5.partitioning" + ) + fake_fingerprinting = ModuleType( + "policyengine_us_data.calibration.local_h5.fingerprinting" + ) + fake_policyengine.__path__ = [] + fake_calibration.__path__ = [] + fake_local_h5.__path__ = [] + fake_partitioning.partition_weighted_work_items = lambda *args, **kwargs: [] + fake_fingerprinting.PublishingInputBundle = object + + class _FakeFingerprintingService: + def build_traceability(self, *args, **kwargs): + return object() + + def compute_scope_fingerprint(self, *args, **kwargs): + return "fake-fingerprint" + + fake_fingerprinting.FingerprintingService = _FakeFingerprintingService + overrides.update( + { + "policyengine_us_data": fake_policyengine, + "policyengine_us_data.calibration": fake_calibration, + "policyengine_us_data.calibration.local_h5": fake_local_h5, + "policyengine_us_data.calibration.local_h5.fingerprinting": ( + fake_fingerprinting + ), + "policyengine_us_data.calibration.local_h5.partitioning": ( + fake_partitioning + ), + } + ) + + with _patched_module_registry(overrides): return importlib.import_module("modal_app.local_area") diff --git a/tests/unit/test_modal_local_area.py b/tests/unit/test_modal_local_area.py index 0e3cd9fd6..e8128db71 100644 --- a/tests/unit/test_modal_local_area.py +++ b/tests/unit/test_modal_local_area.py @@ -1,3 +1,5 @@ +from pathlib import Path + from tests.unit.fixtures.test_modal_local_area import load_local_area_module @@ -28,3 +30,164 @@ def test_build_promote_publish_script_finalizes_complete_release(): assert "should_finalize_local_area_release" in script assert "create_tag=should_finalize" in script assert "upload_manifest(" in script + + +def test_build_publishing_input_bundle_preserves_traceability_inputs(): + local_area = load_local_area_module(stub_policyengine=False) + + bundle = local_area._build_publishing_input_bundle( + weights_path=Path("/tmp/calibration_weights.npy"), + dataset_path=Path("/tmp/source.h5"), + db_path=Path("/tmp/policy_data.db"), + geography_path=Path("/tmp/geography_assignment.npz"), + calibration_package_path=Path("/tmp/calibration_package.pkl"), + run_config_path=Path("/tmp/unified_run_config.json"), + run_id="run-123", + version="1.2.3", + n_clones=4, + seed=42, + legacy_blocks_path=Path("/tmp/stacked_blocks.npy"), + ) + + assert bundle.weights_path == Path("/tmp/calibration_weights.npy") + assert bundle.source_dataset_path == Path("/tmp/source.h5") + assert bundle.target_db_path == Path("/tmp/policy_data.db") + assert bundle.exact_geography_path == Path("/tmp/geography_assignment.npz") + assert bundle.calibration_package_path == Path("/tmp/calibration_package.pkl") + assert bundle.run_config_path == Path("/tmp/unified_run_config.json") + assert bundle.run_id == "run-123" + assert bundle.version == "1.2.3" + assert bundle.n_clones == 4 + assert bundle.seed == 42 + assert bundle.legacy_blocks_path == Path("/tmp/stacked_blocks.npy") + + +def test_resolve_scope_fingerprint_computes_when_no_pin(monkeypatch): + local_area = load_local_area_module(stub_policyengine=False) + + seen = {} + + class FakeFingerprintingService: + def build_traceability(self, *, inputs, scope): + seen["inputs"] = inputs + seen["scope"] = scope + return {"scope": scope, "run_id": inputs.run_id} + + def compute_scope_fingerprint(self, traceability): + seen["traceability"] = traceability + return "computed-fingerprint" + + monkeypatch.setattr( + local_area, + "FingerprintingService", + FakeFingerprintingService, + ) + + bundle = local_area._build_publishing_input_bundle( + weights_path=Path("/tmp/calibration_weights.npy"), + dataset_path=Path("/tmp/source.h5"), + db_path=None, + geography_path=None, + calibration_package_path=None, + run_config_path=None, + run_id="run-123", + version="1.2.3", + n_clones=2, + seed=42, + ) + + fingerprint = local_area._resolve_scope_fingerprint( + inputs=bundle, + scope="regional", + ) + + assert fingerprint == "computed-fingerprint" + assert seen["inputs"] == bundle + assert seen["scope"] == "regional" + assert seen["traceability"] == {"scope": "regional", "run_id": "run-123"} + + +def test_resolve_scope_fingerprint_preserves_matching_pin(monkeypatch, capsys): + local_area = load_local_area_module(stub_policyengine=False) + + class FakeFingerprintingService: + def build_traceability(self, *, inputs, scope): + return scope + + def compute_scope_fingerprint(self, traceability): + return "pinned-fingerprint" + + monkeypatch.setattr( + local_area, + "FingerprintingService", + FakeFingerprintingService, + ) + + bundle = local_area._build_publishing_input_bundle( + weights_path=Path("/tmp/calibration_weights.npy"), + dataset_path=Path("/tmp/source.h5"), + db_path=None, + geography_path=None, + calibration_package_path=None, + run_config_path=None, + run_id="run-123", + version="1.2.3", + n_clones=2, + seed=42, + ) + + fingerprint = local_area._resolve_scope_fingerprint( + inputs=bundle, + scope="regional", + expected_fingerprint="pinned-fingerprint", + ) + + captured = capsys.readouterr() + assert fingerprint == "pinned-fingerprint" + assert "Using pinned fingerprint from pipeline" in captured.out + + +def test_resolve_scope_fingerprint_warns_and_preserves_mismatched_pin( + monkeypatch, capsys +): + local_area = load_local_area_module(stub_policyengine=False) + + class FakeFingerprintingService: + def build_traceability(self, *, inputs, scope): + return scope + + def compute_scope_fingerprint(self, traceability): + return "computed-fingerprint" + + monkeypatch.setattr( + local_area, + "FingerprintingService", + FakeFingerprintingService, + ) + + bundle = local_area._build_publishing_input_bundle( + weights_path=Path("/tmp/calibration_weights.npy"), + dataset_path=Path("/tmp/source.h5"), + db_path=None, + geography_path=None, + calibration_package_path=None, + run_config_path=None, + run_id="run-123", + version="1.2.3", + n_clones=2, + seed=42, + ) + + fingerprint = local_area._resolve_scope_fingerprint( + inputs=bundle, + scope="national", + expected_fingerprint="legacy-fingerprint", + ) + + captured = capsys.readouterr() + assert fingerprint == "legacy-fingerprint" + assert "Pinned fingerprint differs from current national scope fingerprint" in ( + captured.out + ) + assert "legacy-fingerprint" in captured.out + assert "computed-fingerprint" in captured.out diff --git a/tests/unit/test_pipeline.py b/tests/unit/test_pipeline.py index 5aaca8a47..5458c53a7 100644 --- a/tests/unit/test_pipeline.py +++ b/tests/unit/test_pipeline.py @@ -8,7 +8,7 @@ modal = pytest.importorskip("modal") -from modal_app.pipeline import ( +from modal_app.pipeline import ( # noqa: E402 RunMetadata, _step_completed, _record_step, @@ -63,6 +63,39 @@ def test_from_dict(self): assert meta.status == "completed" assert meta.step_timings["build_datasets"]["status"] == "completed" + def test_from_dict_maps_legacy_fingerprint_to_regional_scope(self): + meta = RunMetadata.from_dict( + { + "run_id": "test", + "branch": "main", + "sha": "abc12345deadbeef", + "version": "1.72.3", + "start_time": "2026-03-19T12:00:00Z", + "status": "running", + "fingerprint": "legacy-fingerprint", + } + ) + + assert meta.fingerprint == "legacy-fingerprint" + assert meta.regional_fingerprint == "legacy-fingerprint" + + def test_from_dict_keeps_explicit_regional_fingerprint_when_both_present(self): + meta = RunMetadata.from_dict( + { + "run_id": "test", + "branch": "main", + "sha": "abc12345deadbeef", + "version": "1.72.3", + "start_time": "2026-03-19T12:00:00Z", + "status": "running", + "fingerprint": "legacy-fingerprint", + "regional_fingerprint": "regional-fingerprint", + } + ) + + assert meta.fingerprint == "legacy-fingerprint" + assert meta.regional_fingerprint == "regional-fingerprint" + def test_roundtrip(self): meta = RunMetadata( run_id="1.72.3_abc12345_20260319_120000", @@ -79,6 +112,39 @@ def test_roundtrip(self): assert roundtripped.status == meta.status assert roundtripped.error == meta.error + def test_to_dict_keeps_legacy_fingerprint_alias_in_sync(self): + meta = RunMetadata( + run_id="test", + branch="main", + sha="abc", + version="1.0.0", + start_time="now", + status="running", + regional_fingerprint="regional-fp", + ) + + payload = meta.to_dict() + + assert payload["fingerprint"] == "regional-fp" + assert payload["regional_fingerprint"] == "regional-fp" + + def test_to_dict_preserves_distinct_explicit_regional_fingerprint(self): + meta = RunMetadata( + run_id="test", + branch="main", + sha="abc", + version="1.0.0", + start_time="now", + status="running", + fingerprint="legacy-fp", + regional_fingerprint="regional-fp", + ) + + payload = meta.to_dict() + + assert payload["fingerprint"] == "legacy-fp" + assert payload["regional_fingerprint"] == "regional-fp" + def test_step_timings_default_empty(self): meta = RunMetadata( run_id="test",