diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder.py b/sagemaker-serve/src/sagemaker/serve/model_builder.py index 7c7af2defc..5fbd923437 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder.py @@ -52,7 +52,7 @@ ModelCard, ModelPackageModelCard, ) -from sagemaker.core.utils.utils import logger +from sagemaker.core.utils.utils import logger, Unassigned from sagemaker.core.helper import session_helper from sagemaker.core.helper.session_helper import ( Session, @@ -414,6 +414,7 @@ def __post_init__(self) -> None: if self.log_level is not None: logger.setLevel(self.log_level) + self._base_model_fields_resolved: bool = False self._warn_about_deprecated_parameters(warnings) self._initialize_compute_config() self._initialize_network_config() @@ -680,18 +681,182 @@ def _infer_instance_type_from_jumpstart(self) -> str: raise ValueError(error_msg) + def _resolve_base_model_fields(self): + """Auto-resolve missing BaseModel fields (hub_content_version, recipe_name). + + When a ModelPackage's BaseModel has hub_content_name set but is missing + hub_content_version and/or recipe_name (returned as Unassigned from the + DescribeModelPackage API), this method resolves them automatically: + - hub_content_version: resolved by calling HubContent.get on SageMakerPublicHub + - recipe_name: resolved from the first recipe in the hub document's RecipeCollection + + Note: HubContent.get() supports being called without hub_content_version, + in which case it returns the latest version of the hub content. + """ + if self._base_model_fields_resolved: + return + + model_package = self._fetch_model_package() + if not model_package: + self._base_model_fields_resolved = True + return + + inference_spec = getattr(model_package, "inference_specification", None) + if not inference_spec: + self._base_model_fields_resolved = True + return + + containers = getattr(inference_spec, "containers", None) + if not containers: + self._base_model_fields_resolved = True + return + + base_model = getattr(containers[0], "base_model", None) + if not base_model: + self._base_model_fields_resolved = True + return + + hub_content_name = getattr(base_model, "hub_content_name", None) + if not hub_content_name or isinstance(hub_content_name, Unassigned): + self._base_model_fields_resolved = True + return + + hub_content_version = getattr(base_model, "hub_content_version", None) + recipe_name = getattr(base_model, "recipe_name", None) + + # Cache the HubContent response to avoid redundant API calls + cached_hub_content = None + + # Resolve hub_content_version if missing + if not hub_content_version or isinstance(hub_content_version, Unassigned): + logger.info( + "hub_content_version is missing for hub content '%s'. " + "Resolving automatically from SageMakerPublicHub...", + hub_content_name, + ) + try: + cached_hub_content = HubContent.get( + hub_content_type="Model", + hub_name="SageMakerPublicHub", + hub_content_name=hub_content_name, + ) + base_model.hub_content_version = ( + cached_hub_content.hub_content_version + ) + logger.info( + "Resolved hub_content_version to '%s' " + "for hub content '%s'.", + cached_hub_content.hub_content_version, + hub_content_name, + ) + except (ClientError, ValueError) as e: + logger.warning( + "Failed to resolve hub_content_version " + "for hub content '%s': %s", + hub_content_name, + e, + ) + self._base_model_fields_resolved = True + return + + # Resolve recipe_name if missing + if not recipe_name or isinstance(recipe_name, Unassigned): + logger.info( + "recipe_name is missing for hub content '%s'. " + "Resolving from hub document RecipeCollection...", + hub_content_name, + ) + try: + # Reuse cached hub content if available and version matches + if ( + cached_hub_content is not None + and cached_hub_content.hub_content_version + == base_model.hub_content_version + ): + hub_content = cached_hub_content + else: + hub_content = HubContent.get( + hub_content_type="Model", + hub_name="SageMakerPublicHub", + hub_content_name=base_model.hub_content_name, + hub_content_version=( + base_model.hub_content_version + ), + ) + hub_document = json.loads( + hub_content.hub_content_document + ) + recipe_collection = hub_document.get( + "RecipeCollection", [] + ) + if recipe_collection: + resolved_recipe = recipe_collection[0].get( + "Name", "" + ) + if resolved_recipe: + base_model.recipe_name = resolved_recipe + logger.info( + "Resolved recipe_name to '%s' " + "for hub content '%s'.", + resolved_recipe, + hub_content_name, + ) + else: + logger.warning( + "RecipeCollection found but first " + "recipe has no Name for hub " + "content '%s'.", + hub_content_name, + ) + else: + logger.warning( + "No RecipeCollection found in hub " + "document for hub content '%s'. " + "recipe_name could not be " + "auto-resolved.", + hub_content_name, + ) + except (ClientError, ValueError) as e: + logger.warning( + "Failed to resolve recipe_name " + "for hub content '%s': %s", + hub_content_name, + e, + ) + + self._base_model_fields_resolved = True + def _fetch_hub_document_for_custom_model(self) -> dict: + """Fetch the hub document for a custom (fine-tuned) model. + + Calls _resolve_base_model_fields() first to ensure hub_content_version + is populated. If hub_content_version is still Unassigned after + resolution (e.g. resolution failed), HubContent.get() is called + without a version parameter, which returns the latest version. + """ from sagemaker.core.shapes import BaseModel as CoreBaseModel + self._resolve_base_model_fields() + base_model: CoreBaseModel = ( - self._fetch_model_package().inference_specification.containers[0].base_model + self._fetch_model_package() + .inference_specification.containers[0].base_model + ) + + hub_content_version = getattr( + base_model, "hub_content_version", None ) - hub_content = HubContent.get( + get_kwargs = dict( hub_content_type="Model", hub_name="SageMakerPublicHub", hub_content_name=base_model.hub_content_name, - hub_content_version=base_model.hub_content_version, ) + if hub_content_version and not isinstance( + hub_content_version, Unassigned + ): + get_kwargs["hub_content_version"] = hub_content_version + + hub_content = HubContent.get(**get_kwargs) return json.loads(hub_content.hub_content_document) def _fetch_hosting_configs_for_custom_model(self) -> dict: @@ -937,9 +1102,29 @@ def _is_gpu_instance(self, instance_type: str) -> bool: def _fetch_and_cache_recipe_config(self): """Fetch and cache image URI, compute requirements, and s3_upload_path from recipe during build.""" + # _fetch_hub_document_for_custom_model calls _resolve_base_model_fields + # internally, so no need to call it separately here. hub_document = self._fetch_hub_document_for_custom_model() model_package = self._fetch_model_package() - recipe_name = model_package.inference_specification.containers[0].base_model.recipe_name + base_model = ( + model_package.inference_specification + .containers[0].base_model + ) + hub_content_name = getattr( + base_model, "hub_content_name", "unknown" + ) + recipe_name = getattr(base_model, "recipe_name", None) + if not recipe_name or isinstance(recipe_name, Unassigned): + raise ValueError( + f"recipe_name is missing from the model package's " + f"BaseModel (hub_content_name='{hub_content_name}') " + f"and could not be auto-resolved from the hub " + f"document. Please ensure the model package has a " + f"valid recipe_name set, or manually set it before " + f"calling build(). Example: model_package." + f"inference_specification.containers[0].base_model" + f".recipe_name = 'your-recipe-name'" + ) if not self.s3_upload_path: self.s3_upload_path = model_package.inference_specification.containers[ @@ -1060,7 +1245,11 @@ def _is_nova_model(self) -> bool: if not base_model: return False recipe_name = getattr(base_model, "recipe_name", "") or "" + if isinstance(recipe_name, Unassigned): + recipe_name = "" hub_content_name = getattr(base_model, "hub_content_name", "") or "" + if isinstance(hub_content_name, Unassigned): + hub_content_name = "" return "nova" in recipe_name.lower() or "nova" in hub_content_name.lower() def _is_nova_model_for_telemetry(self) -> bool: @@ -1076,8 +1265,12 @@ def _get_nova_hosting_config(self, instance_type=None): Nova training recipes don't have hosting configs in the JumpStart hub document. This provides the hardcoded fallback, matching Rhinestone's getNovaHostingConfigs(). """ + self._resolve_base_model_fields() model_package = self._fetch_model_package() - hub_content_name = model_package.inference_specification.containers[0].base_model.hub_content_name + hub_content_name = ( + model_package.inference_specification + .containers[0].base_model.hub_content_name + ) configs = self._NOVA_HOSTING_CONFIGS.get(hub_content_name) if not configs: diff --git a/sagemaker-serve/tests/unit/test_resolve_base_model_fields.py b/sagemaker-serve/tests/unit/test_resolve_base_model_fields.py new file mode 100644 index 0000000000..203d389e4d --- /dev/null +++ b/sagemaker-serve/tests/unit/test_resolve_base_model_fields.py @@ -0,0 +1,471 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Tests for _resolve_base_model_fields and related Unassigned handling.""" +from __future__ import annotations + +import json +import pytest +from unittest.mock import MagicMock, patch +from botocore.exceptions import ClientError + +from sagemaker.core.utils.utils import Unassigned + + +def _make_model_builder(**kwargs): + """Create a ModelBuilder with mocked session to avoid real AWS calls.""" + with patch("sagemaker.serve.model_builder.Session"): + with patch( + "sagemaker.serve.model_builder.get_execution_role", + return_value=( + "arn:aws:iam::123456789012:role/SageMakerRole" + ), + ): + from sagemaker.serve.model_builder import ModelBuilder + defaults = dict( + model="dummy-model", + role_arn=( + "arn:aws:iam::123456789012:role/SageMakerRole" + ), + ) + defaults.update(kwargs) + mb = ModelBuilder(**defaults) + # Reset the resolution flag so tests can trigger it + mb._base_model_fields_resolved = False + return mb + + +def _make_base_model( + hub_content_name=None, + hub_content_version=None, + recipe_name=None, +): + """Create a mock BaseModel with the given fields.""" + base_model = MagicMock() + base_model.hub_content_name = ( + hub_content_name + if hub_content_name is not None + else Unassigned() + ) + base_model.hub_content_version = ( + hub_content_version + if hub_content_version is not None + else Unassigned() + ) + base_model.recipe_name = ( + recipe_name + if recipe_name is not None + else Unassigned() + ) + return base_model + + +def _make_model_package(base_model): + """Create a mock ModelPackage with the given base_model.""" + container = MagicMock() + container.base_model = base_model + container.model_data_source = MagicMock() + container.model_data_source.s3_data_source = MagicMock() + container.model_data_source.s3_data_source.s3_uri = ( + "s3://bucket/path" + ) + + model_package = MagicMock() + model_package.inference_specification.containers = [container] + return model_package + + +def _make_hub_content( + hub_content_version="1.0.0", + hub_content_document=None, +): + """Create a mock HubContent object.""" + hc = MagicMock() + hc.hub_content_version = hub_content_version + if hub_content_document is None: + hub_content_document = json.dumps({ + "RecipeCollection": [ + { + "Name": "auto-resolved-recipe", + "HostingConfigs": [], + } + ], + "HostingConfigs": [], + }) + hc.hub_content_document = hub_content_document + return hc + + +class TestResolveBaseModelFields: + """Tests for _resolve_base_model_fields method.""" + + @patch("sagemaker.serve.model_builder.HubContent") + def test_resolve_missing_hub_content_version( + self, mock_hub_content_cls + ): + """hub_content_version Unassigned => resolved from HubContent.get.""" + mb = _make_model_builder() + base_model = _make_base_model( + hub_content_name="huggingface-reasoning-qwen3-32b", + hub_content_version=None, + recipe_name="some-recipe", + ) + model_package = _make_model_package(base_model) + mb._fetch_model_package = MagicMock( + return_value=model_package + ) + + mock_hc = _make_hub_content(hub_content_version="2.5.0") + mock_hub_content_cls.get.return_value = mock_hc + + mb._resolve_base_model_fields() + + assert base_model.hub_content_version == "2.5.0" + assert base_model.recipe_name == "some-recipe" + + @patch("sagemaker.serve.model_builder.HubContent") + def test_resolve_missing_recipe_name( + self, mock_hub_content_cls + ): + """recipe_name Unassigned => resolved from RecipeCollection.""" + mb = _make_model_builder() + base_model = _make_base_model( + hub_content_name="huggingface-reasoning-qwen3-32b", + hub_content_version="1.0.0", + recipe_name=None, + ) + model_package = _make_model_package(base_model) + mb._fetch_model_package = MagicMock( + return_value=model_package + ) + + hub_doc = json.dumps({ + "RecipeCollection": [ + { + "Name": "verl-grpo-rlaif-qwen-3-32b-lora", + "HostingConfigs": [], + } + ], + }) + mock_hc = _make_hub_content( + hub_content_version="1.0.0", + hub_content_document=hub_doc, + ) + mock_hub_content_cls.get.return_value = mock_hc + + mb._resolve_base_model_fields() + + assert base_model.recipe_name == ( + "verl-grpo-rlaif-qwen-3-32b-lora" + ) + + @patch("sagemaker.serve.model_builder.HubContent") + def test_noop_when_all_fields_present( + self, mock_hub_content_cls + ): + """All fields present => HubContent.get not called.""" + mb = _make_model_builder() + base_model = _make_base_model( + hub_content_name="huggingface-reasoning-qwen3-32b", + hub_content_version="1.0.0", + recipe_name="some-recipe", + ) + model_package = _make_model_package(base_model) + mb._fetch_model_package = MagicMock( + return_value=model_package + ) + + mb._resolve_base_model_fields() + + mock_hub_content_cls.get.assert_not_called() + assert base_model.hub_content_version == "1.0.0" + assert base_model.recipe_name == "some-recipe" + + @patch("sagemaker.serve.model_builder.HubContent") + def test_resolve_both_version_and_recipe( + self, mock_hub_content_cls + ): + """Both Unassigned => both resolved.""" + mb = _make_model_builder() + base_model = _make_base_model( + hub_content_name="huggingface-reasoning-qwen3-32b", + hub_content_version=None, + recipe_name=None, + ) + model_package = _make_model_package(base_model) + mb._fetch_model_package = MagicMock( + return_value=model_package + ) + + hub_doc = json.dumps({ + "RecipeCollection": [ + { + "Name": "auto-resolved-recipe", + "HostingConfigs": [], + } + ], + }) + mock_hc = _make_hub_content( + hub_content_version="3.0.0", + hub_content_document=hub_doc, + ) + mock_hub_content_cls.get.return_value = mock_hc + + mb._resolve_base_model_fields() + + assert base_model.hub_content_version == "3.0.0" + assert base_model.recipe_name == "auto-resolved-recipe" + + @patch("sagemaker.serve.model_builder.HubContent") + def test_fetch_hub_document_works_after_resolution( + self, mock_hub_content_cls + ): + """_fetch_hub_document_for_custom_model works after resolution.""" + mb = _make_model_builder() + base_model = _make_base_model( + hub_content_name="huggingface-reasoning-qwen3-32b", + hub_content_version=None, + recipe_name="some-recipe", + ) + model_package = _make_model_package(base_model) + mb._fetch_model_package = MagicMock( + return_value=model_package + ) + + hub_doc = json.dumps( + {"HostingConfigs": [{"Profile": "Default"}]} + ) + mock_hc = _make_hub_content( + hub_content_version="1.0.0", + hub_content_document=hub_doc, + ) + mock_hub_content_cls.get.return_value = mock_hc + + result = mb._fetch_hub_document_for_custom_model() + + assert result == { + "HostingConfigs": [{"Profile": "Default"}] + } + + @patch("sagemaker.serve.model_builder.HubContent") + def test_no_base_model_is_noop( + self, mock_hub_content_cls + ): + """No base_model => method returns without error.""" + mb = _make_model_builder() + container = MagicMock() + container.base_model = None + model_package = MagicMock() + model_package.inference_specification.containers = [ + container + ] + mb._fetch_model_package = MagicMock( + return_value=model_package + ) + + mb._resolve_base_model_fields() + + mock_hub_content_cls.get.assert_not_called() + + @patch("sagemaker.serve.model_builder.HubContent") + def test_no_hub_content_name_is_noop( + self, mock_hub_content_cls + ): + """hub_content_name Unassigned => no HubContent.get call.""" + mb = _make_model_builder() + base_model = _make_base_model( + hub_content_name=None, + hub_content_version=None, + recipe_name=None, + ) + model_package = _make_model_package(base_model) + mb._fetch_model_package = MagicMock( + return_value=model_package + ) + + mb._resolve_base_model_fields() + + mock_hub_content_cls.get.assert_not_called() + + @patch("sagemaker.serve.model_builder.HubContent") + def test_is_nova_model_with_unassigned_fields( + self, mock_hub_content_cls + ): + """_is_nova_model returns False when fields are Unassigned.""" + mb = _make_model_builder() + base_model = _make_base_model( + hub_content_name=None, + hub_content_version=None, + recipe_name=None, + ) + model_package = _make_model_package(base_model) + mb._fetch_model_package = MagicMock( + return_value=model_package + ) + + result = mb._is_nova_model() + + assert result is False + + @patch("sagemaker.serve.model_builder.HubContent") + def test_fetch_and_cache_recipe_raises_when_unresolvable( + self, mock_hub_content_cls + ): + """recipe_name unresolvable => ValueError from _fetch_and_cache.""" + mb = _make_model_builder() + base_model = _make_base_model( + hub_content_name="huggingface-reasoning-qwen3-32b", + hub_content_version="1.0.0", + recipe_name=None, + ) + model_package = _make_model_package(base_model) + mb._fetch_model_package = MagicMock( + return_value=model_package + ) + + hub_doc = json.dumps( + {"RecipeCollection": [], "HostingConfigs": []} + ) + mock_hc = _make_hub_content( + hub_content_version="1.0.0", + hub_content_document=hub_doc, + ) + mock_hub_content_cls.get.return_value = mock_hc + + with pytest.raises(ValueError, match="recipe_name is missing"): + mb._fetch_and_cache_recipe_config() + + @patch("sagemaker.serve.model_builder.HubContent") + def test_resolve_graceful_on_hub_content_get_failure( + self, mock_hub_content_cls + ): + """When HubContent.get fails, resolution returns early. + + Downstream code (_fetch_and_cache_recipe_config) should still + raise the appropriate ValueError because recipe_name remains + Unassigned. + """ + mb = _make_model_builder() + base_model = _make_base_model( + hub_content_name="huggingface-reasoning-qwen3-32b", + hub_content_version=None, + recipe_name=None, + ) + model_package = _make_model_package(base_model) + mb._fetch_model_package = MagicMock( + return_value=model_package + ) + + # Simulate HubContent.get raising a ClientError + mock_hub_content_cls.get.side_effect = ClientError( + error_response={ + "Error": { + "Code": "ResourceNotFoundException", + "Message": "Hub content not found", + } + }, + operation_name="DescribeHubContent", + ) + + # _resolve_base_model_fields should not raise + mb._resolve_base_model_fields() + + # Fields should still be Unassigned + assert isinstance( + base_model.hub_content_version, Unassigned + ) + assert isinstance(base_model.recipe_name, Unassigned) + # The flag should be set so it doesn't retry + assert mb._base_model_fields_resolved is True + + @patch("sagemaker.serve.model_builder.HubContent") + def test_resolve_failure_then_fetch_and_cache_raises( + self, mock_hub_content_cls + ): + """When resolution fails, _fetch_and_cache_recipe_config raises. + + This tests the full flow: resolution fails gracefully, then + downstream code raises ValueError because recipe_name is still + Unassigned. + """ + mb = _make_model_builder() + base_model = _make_base_model( + hub_content_name="huggingface-reasoning-qwen3-32b", + hub_content_version=None, + recipe_name=None, + ) + model_package = _make_model_package(base_model) + mb._fetch_model_package = MagicMock( + return_value=model_package + ) + + # First call (from _resolve_base_model_fields) fails + # Second call (from _fetch_hub_document_for_custom_model) + # succeeds but returns empty RecipeCollection + hub_doc = json.dumps( + {"RecipeCollection": [], "HostingConfigs": []} + ) + mock_hc = _make_hub_content( + hub_content_version="1.0.0", + hub_content_document=hub_doc, + ) + + call_count = [0] + + def side_effect(**kwargs): + call_count[0] += 1 + if call_count[0] == 1: + raise ClientError( + error_response={ + "Error": { + "Code": "ResourceNotFoundException", + "Message": "Not found", + } + }, + operation_name="DescribeHubContent", + ) + return mock_hc + + mock_hub_content_cls.get.side_effect = side_effect + + with pytest.raises( + ValueError, match="recipe_name is missing" + ): + mb._fetch_and_cache_recipe_config() + + @patch("sagemaker.serve.model_builder.HubContent") + def test_idempotent_second_call_is_noop( + self, mock_hub_content_cls + ): + """Second call to _resolve_base_model_fields is a no-op.""" + mb = _make_model_builder() + base_model = _make_base_model( + hub_content_name="huggingface-reasoning-qwen3-32b", + hub_content_version=None, + recipe_name="some-recipe", + ) + model_package = _make_model_package(base_model) + mb._fetch_model_package = MagicMock( + return_value=model_package + ) + + mock_hc = _make_hub_content(hub_content_version="2.0.0") + mock_hub_content_cls.get.return_value = mock_hc + + mb._resolve_base_model_fields() + assert base_model.hub_content_version == "2.0.0" + + # Reset mock to verify no additional calls + mock_hub_content_cls.get.reset_mock() + + mb._resolve_base_model_fields() + mock_hub_content_cls.get.assert_not_called()