Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,40 +9,22 @@
"policyengine_version": "4.10.0",
"us": {
"model_version": "1.500.0",
"data_version": "1.110.12",
"data_artifact_revision": "1.110.12",
"default_dataset": "enhanced_cps_2024",
"default_dataset_uri": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.110.12",
"data_version": "populace-us-2024-test",
"data_artifact_revision": "us-artifact-revision",
"default_dataset": "populace_us_2024",
"default_dataset_uri": "hf://policyengine/populace-us/populace_us_2024.h5@us-artifact-revision",
"dataset_uris": {
"enhanced_cps_2024": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.110.12",
"cps_2023": "hf://policyengine/policyengine-us-data/cps_2023.h5@1.110.12",
"pooled_3_year_cps_2023": "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5@1.110.12",
"states/UT": "hf://policyengine/policyengine-us-data/states/UT.h5@1.115.5",
},
"dataset_aliases": {
"enhanced_cps": "enhanced_cps_2024",
"enhanced_cps_2024": "enhanced_cps_2024",
"cps": "cps_2023",
"cps_2023": "cps_2023",
"pooled_cps": "pooled_3_year_cps_2023",
"pooled_3_year_cps_2023": "pooled_3_year_cps_2023",
"populace_us_2024": "hf://policyengine/populace-us/populace_us_2024.h5@us-artifact-revision",
},
},
"uk": {
"model_version": "2.66.0",
"data_version": "1.40.3",
"data_artifact_revision": "1.40.3",
"default_dataset": "enhanced_frs_2023_24",
"default_dataset_uri": "hf://policyengine/policyengine-uk-data-private/enhanced_frs_2023_24.h5@1.40.3",
"data_version": "populace-uk-2023-test",
"data_artifact_revision": "uk-artifact-revision",
"default_dataset": "populace_uk_2023",
"default_dataset_uri": "hf://policyengine/populace-uk-private/populace_uk_2023.h5@uk-artifact-revision",
"dataset_uris": {
"enhanced_frs_2023_24": "hf://policyengine/policyengine-uk-data-private/enhanced_frs_2023_24.h5@1.40.3",
"frs_2023_24": "hf://policyengine/policyengine-uk-data-private/frs_2023_24.h5@1.40.3",
},
"dataset_aliases": {
"enhanced_frs": "enhanced_frs_2023_24",
"enhanced_frs_2023_24": "enhanced_frs_2023_24",
"frs": "frs_2023_24",
"frs_2023_24": "frs_2023_24",
"populace_uk_2023": "hf://policyengine/populace-uk-private/populace_uk_2023.h5@uk-artifact-revision",
},
},
}
Expand Down Expand Up @@ -102,14 +84,15 @@ def _runtime_dataset_uri(
selected_revision = revision or existing_revision

if dataset_without_revision.startswith("hf://policyengine/"):
if (
selected_revision == country_bundle.get("data_artifact_revision")
and revision is None
):
selected_revision = country_bundle["data_version"]
remainder = dataset_without_revision.removeprefix("hf://policyengine/")
bucket, _, path = remainder.partition("/")
dataset_without_revision = f"gs://{bucket}/{path}"
if bucket.startswith("policyengine-") and "-data" in bucket:
if (
selected_revision == country_bundle.get("data_artifact_revision")
and revision is None
):
selected_revision = country_bundle["data_version"]
dataset_without_revision = f"gs://{bucket}/{path}"

if selected_revision is None and use_bundle_default:
selected_revision = country_bundle["data_version"]
Expand Down Expand Up @@ -140,7 +123,9 @@ def resolve_test_dataset_uri(
)

dataset_name, revision = _split_revision(dataset)
dataset_name = country_bundle["dataset_aliases"].get(dataset_name, dataset_name)
aliases = country_bundle.get("dataset_aliases")
if isinstance(aliases, dict):
dataset_name = aliases.get(dataset_name, dataset_name)
dataset_uri = country_bundle["dataset_uris"].get(dataset_name, dataset_name)
if revision is not None and dataset_uri == dataset_name:
return dataset
Expand Down
4 changes: 2 additions & 2 deletions projects/policyengine-api-simulation/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ dependencies = [
"pydantic-settings (>=2.7.1,<3.0.0)",
"opentelemetry-instrumentation-fastapi (>=0.51b0,<0.52)",
"policyengine-fastapi",
"policyengine==4.18.6",
"policyengine==4.18.7",
"policyengine-core==3.28.0",
"policyengine-uk==2.89.2",
"policyengine-us==1.745.0",
"policyengine-us==1.729.0",
"tables>=3.10.2",
"modal>=0.73.0",
"logfire>=3.0.0",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,23 @@ def _revision_from_dataset_uri(dataset_uri: str | None) -> str | None:
return revision


def _bundle_response_data_version(
    *,
    country_bundle: dict,
    requested_dataset: str | None,
    requested_data_version: str | None,
    resolved_dataset: str | None,
) -> str | None:
    """Choose the data version to report alongside a resolved dataset.

    Precedence, highest first:

    1. ``requested_data_version`` when the caller supplied one.
    2. When ``requested_dataset`` itself carries a revision, whatever
       revision ``_revision_from_dataset_uri`` extracts from
       ``resolved_dataset``.
    3. The bundle's ``"data_version"`` entry, if it is a string.
    4. The revision extracted from ``resolved_dataset``.

    Returns ``None`` only when the final fallback yields ``None``.
    """
    if requested_data_version is not None:
        return requested_data_version

    # A revision on the request takes priority over the bundle-level
    # data version, so the bundle is only consulted when none was given.
    request_has_revision = _revision_from_dataset_uri(requested_dataset) is not None
    if not request_has_revision:
        bundle_version = country_bundle.get("data_version")
        if isinstance(bundle_version, str):
            return bundle_version

    return _revision_from_dataset_uri(resolved_dataset)


def _bundle_certified_hf_uri_roots(country_bundle: dict) -> set[str]:
roots: set[str] = set()
default_uri = country_bundle.get("default_dataset_uri")
Expand Down Expand Up @@ -192,6 +209,8 @@ def _resolve_dataset_uri_from_app_bundle(
if not isinstance(country_bundle, dict):
return requested_data

# Older Modal snapshots may contain aliases. Newly published bundle snapshots
# resolve direct .py dataset names through dataset_uris instead.
aliases = country_bundle.get("dataset_aliases")
if not isinstance(aliases, dict):
aliases = {}
Expand Down Expand Up @@ -506,11 +525,11 @@ def _build_policyengine_bundle(
requested_data=requested_dataset,
requested_data_version=requested_data_version,
)
data_version = (
requested_data_version
if requested_data_version is not None
else _revision_from_dataset_uri(resolved_dataset)
or country_bundle.get("data_version")
data_version = _bundle_response_data_version(
country_bundle=country_bundle,
requested_dataset=requested_dataset,
requested_data_version=requested_data_version,
resolved_dataset=resolved_dataset,
)
model_version = country_bundle.get("model_version") or resolution.response_version
policyengine_version = app_bundle.get(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ class CountryBundleMetadata(TypedDict):
default_dataset_uri: str
dataset_uris: dict[str, str]
dataset_repo_types: dict[str, str]
dataset_aliases: dict[str, str]


class BundleManifestMetadata(TypedDict):
Expand Down Expand Up @@ -79,10 +78,7 @@ def _is_newer_version(candidate: str, current: str | None) -> bool:


def _country_bundle_metadata(country: str) -> CountryBundleMetadata:
from policyengine_api_simulation.release_bundle import (
DATASET_ALIASES,
get_country_release_bundle,
)
from policyengine_api_simulation.release_bundle import get_country_release_bundle

bundle = get_country_release_bundle(country)
return {
Expand All @@ -96,7 +92,6 @@ def _country_bundle_metadata(country: str) -> CountryBundleMetadata:
"default_dataset_uri": bundle.default_dataset_uri,
"dataset_uris": dict(bundle.dataset_uris),
"dataset_repo_types": dict(bundle.dataset_repo_types),
"dataset_aliases": dict(DATASET_ALIASES.get(bundle.country, {})),
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,42 +24,6 @@

SUPPORTED_COUNTRIES = frozenset({"us", "uk"})
BUNDLE_RECEIPT_FILENAME = ".policyengine-bundle-receipt.json"
LEGACY_US_DATA_REVISION = "1.110.12"
LEGACY_ENHANCED_CPS_URI = (
"hf://policyengine/policyengine-us-data/"
f"enhanced_cps_2024.h5@{LEGACY_US_DATA_REVISION}"
)

DATASET_ALIASES: dict[str, dict[str, str]] = {
"us": {
"enhanced_cps": LEGACY_ENHANCED_CPS_URI,
"enhanced_cps_2024": LEGACY_ENHANCED_CPS_URI,
"cps_small": "cps_small_2024",
"cps_small_2024": "cps_small_2024",
"cps": (
"hf://policyengine/policyengine-us-data/"
f"cps_2023.h5@{LEGACY_US_DATA_REVISION}"
),
"cps_2023": (
"hf://policyengine/policyengine-us-data/"
f"cps_2023.h5@{LEGACY_US_DATA_REVISION}"
),
"pooled_cps": (
"hf://policyengine/policyengine-us-data/"
f"pooled_3_year_cps_2023.h5@{LEGACY_US_DATA_REVISION}"
),
"pooled_3_year_cps_2023": (
"hf://policyengine/policyengine-us-data/"
f"pooled_3_year_cps_2023.h5@{LEGACY_US_DATA_REVISION}"
),
},
"uk": {
"enhanced_frs": "enhanced_frs_2023_24",
"enhanced_frs_2023_24": "enhanced_frs_2023_24",
"frs": "frs_2023_24",
"frs_2023_24": "frs_2023_24",
},
}


@dataclass(frozen=True)
Expand Down Expand Up @@ -332,20 +296,15 @@ def resolve_bundle_dataset_name(country: str, requested_data: str | None) -> str
return requested_data

requested_without_revision, revision = _split_requested_revision(requested_data)
aliased = DATASET_ALIASES.get(bundle.country, {}).get(
requested_without_revision, requested_data
)
if revision is not None:
if "://" in aliased:
return _with_hf_revision_unvalidated(aliased, revision)
uri = bundle.dataset_uris.get(aliased)
uri = bundle.dataset_uris.get(requested_without_revision)
if uri is None:
raise ValueError(
"Unknown dataset revision reference "
f"{requested_data!r} for country {bundle.country!r}"
)
return _with_hf_revision_unvalidated(uri, revision)
return aliased
return requested_without_revision


def resolve_bundle_dataset_uri(country: str, requested_data: str | None) -> str:
Expand Down Expand Up @@ -402,11 +361,10 @@ def _is_default_bundle_dataset(
requested_revision=requested_revision,
requested_data_version=requested_data_version,
)
aliased = DATASET_ALIASES.get(bundle.country, {}).get(
requested_without_revision,
requested_without_revision,
)
return aliased == bundle.default_dataset and revision in {None, bundle.data_version}
return requested_without_revision == bundle.default_dataset and revision in {
None,
bundle.data_version,
}


def resolve_local_bundle_dataset_path(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_create_initial_batch_state_builds_queued_years_and_run_id():
start_year="2026",
window_size=3,
max_parallel=2,
data="enhanced_cps_2024",
data="custom_dataset_label",
scope="macro",
reform={},
_telemetry={
Expand All @@ -52,7 +52,7 @@ def test_create_initial_batch_state_builds_queued_years_and_run_id():
assert state.target == "general"
assert state.years == ["2026", "2027", "2028"]
assert state.queued_years == ["2026", "2027", "2028"]
assert state.request_payload["data"] == "enhanced_cps_2024"
assert state.request_payload["data"] == "custom_dataset_label"
assert state.request_payload["scope"] == "macro"
assert state.request_payload["reform"] == {}
assert state.run_id == "batch-run-123"
Expand Down
Loading