From 05e59de8d24c5a3dc4b62f41b0afe1872ff15fc3 Mon Sep 17 00:00:00 2001 From: SageMaker Bot <49924207+sagemaker-bot@users.noreply.github.com> Date: Tue, 14 Apr 2026 10:47:07 -0700 Subject: [PATCH 1/7] fix: ModelBuilder.deploy() should expose DataCacheConfig and other CreateInferenceCom (5750) --- .../src/sagemaker/serve/model_builder.py | 2 + .../sagemaker/serve/model_builder_utils.py | 65 +++++++++++++++++++ 2 files changed, 67 insertions(+) diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder.py b/sagemaker-serve/src/sagemaker/serve/model_builder.py index 7c7af2defc..0d24accd98 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder.py @@ -45,6 +45,8 @@ ModelLifeCycle, DriftCheckBaselines, InferenceComponentComputeResourceRequirements, + InferenceComponentDataCacheConfig, + InferenceComponentContainerSpecification, ) from sagemaker.core.resources import ( ModelPackage, diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py b/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py index 56f3070346..25f4e44c26 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py @@ -3369,6 +3369,71 @@ def _extract_speculative_draft_model_provider( return "auto" + def _resolve_data_cache_config(self, data_cache_config): + """Resolve data_cache_config to InferenceComponentDataCacheConfig. + + Args: + data_cache_config: Either a dict with 'enable_caching' key, + an InferenceComponentDataCacheConfig instance, or None. + + Returns: + InferenceComponentDataCacheConfig or None. + + Raises: + ValueError: If data_cache_config is an unsupported type. + """ + if data_cache_config is None: + return None + + from sagemaker.core.shapes import InferenceComponentDataCacheConfig + + if isinstance(data_cache_config, InferenceComponentDataCacheConfig): + return data_cache_config + elif isinstance(data_cache_config, dict): + return InferenceComponentDataCacheConfig( + enable_caching=data_cache_config.get("enable_caching", False) + ) + else: + raise ValueError( + f"data_cache_config must be a dict with 'enable_caching' key or an " + f"InferenceComponentDataCacheConfig instance, got {type(data_cache_config)}" + ) + + def _resolve_container_spec(self, container): + """Resolve container to InferenceComponentContainerSpecification. + + Args: + container: Either a dict with container config keys (image, artifact_url, + environment), an InferenceComponentContainerSpecification instance, or None. + + Returns: + InferenceComponentContainerSpecification or None. + + Raises: + ValueError: If container is an unsupported type. + """ + if container is None: + return None + + from sagemaker.core.shapes import InferenceComponentContainerSpecification + + if isinstance(container, InferenceComponentContainerSpecification): + return container + elif isinstance(container, dict): + kwargs = {} + if "image" in container: + kwargs["image"] = container["image"] + if "artifact_url" in container: + kwargs["artifact_url"] = container["artifact_url"] + if "environment" in container: + kwargs["environment"] = container["environment"] + return InferenceComponentContainerSpecification(**kwargs) + else: + raise ValueError( + f"container must be a dict or an InferenceComponentContainerSpecification " + f"instance, got {type(container)}" + ) + def get_huggingface_model_metadata( self, model_id: str, hf_hub_token: Optional[str] = None ) -> dict: From 8bc8db347fce8e6804972641b61bd86823a2ee2a Mon Sep 17 00:00:00 2001 From: SageMaker Bot <49924207+sagemaker-bot@users.noreply.github.com> Date: Tue, 14 Apr 2026 11:13:18 -0700 Subject: [PATCH 2/7] fix: address review comments (iteration #1) --- .../src/sagemaker/serve/model_builder.py | 62 ++++++- .../sagemaker/serve/model_builder_utils.py | 28 +++- .../sagemaker/serve/test_resolve_ic_params.py | 151 ++++++++++++++++++ 3 files changed, 232 insertions(+), 9 deletions(-) create mode 100644 tests/unit/sagemaker/serve/test_resolve_ic_params.py diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder.py b/sagemaker-serve/src/sagemaker/serve/model_builder.py index 0d24accd98..6be20ed7aa 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder.py @@ -2980,6 +2980,34 @@ def _deploy_core_endpoint(self, **kwargs): "StartupParameters": startup_parameters, "ComputeResourceRequirements": resources.get_compute_resource_requirements(), } + + # Wire optional IC-level parameters into the specification + ic_data_cache_config = kwargs.get("data_cache_config") + if ic_data_cache_config is not None: + resolved_cache_config = self._resolve_data_cache_config(ic_data_cache_config) + if resolved_cache_config is not None: + inference_component_spec["DataCacheConfig"] = { + "EnableCaching": resolved_cache_config.enable_caching + } + + ic_base_component_name = kwargs.get("base_inference_component_name") + if ic_base_component_name is not None: + inference_component_spec["BaseInferenceComponentName"] = ic_base_component_name + + ic_container = kwargs.get("container") + if ic_container is not None: + resolved_container = self._resolve_container_spec(ic_container) + if resolved_container is not None: + container_dict = {} + if hasattr(resolved_container, "image") and resolved_container.image: + container_dict["Image"] = resolved_container.image + if hasattr(resolved_container, "artifact_url") and resolved_container.artifact_url: + container_dict["ArtifactUrl"] = resolved_container.artifact_url + if hasattr(resolved_container, "environment") and resolved_container.environment: + container_dict["Environment"] = resolved_container.environment + if container_dict: + inference_component_spec["Container"] = container_dict + runtime_config = {"CopyCount": resources.copy_count} self.inference_component_name = ( inference_component_name @@ -2987,11 +3015,14 @@ def _deploy_core_endpoint(self, **kwargs): or unique_name_from_base(self.model_name) ) + # Use user-provided variant_name or default to "AllTraffic" + ic_variant_name = kwargs.get("variant_name", "AllTraffic") + # [TODO]: Add endpoint_logging support self.sagemaker_session.create_inference_component( inference_component_name=self.inference_component_name, endpoint_name=self.endpoint_name, - variant_name="AllTraffic", # default variant name + variant_name=ic_variant_name, specification=inference_component_spec, runtime_config=runtime_config, tags=tags, @@ -4129,6 +4160,10 @@ def deploy( ] = None, custom_orchestrator_instance_type: str = None, custom_orchestrator_initial_instance_count: int = None, + data_cache_config: Optional[Union["InferenceComponentDataCacheConfig", Dict[str, Any]]] = None, + base_inference_component_name: Optional[str] = None, + container: Optional[Union["InferenceComponentContainerSpecification", Dict[str, Any]]] = None, + variant_name: Optional[str] = None, **kwargs, ) -> Union[Endpoint, LocalEndpoint, Transformer]: """Deploy the built model to an ``Endpoint``. @@ -4162,6 +4197,21 @@ def deploy( orchestrator deployment. (Default: None). custom_orchestrator_initial_instance_count (int, optional): Initial instance count for custom orchestrator deployment. (Default: None). + data_cache_config (Union[InferenceComponentDataCacheConfig, dict], optional): + Data cache configuration for the inference component. Enables caching of model + artifacts and container images on instances for faster auto-scaling cold starts. + Can be a dict with 'enable_caching' key (e.g., {'enable_caching': True}) or an + InferenceComponentDataCacheConfig instance. (Default: None). + base_inference_component_name (str, optional): Name of the base inference component + for adapter deployments (e.g., LoRA adapters attached to a base model). + (Default: None). + container (Union[InferenceComponentContainerSpecification, dict], optional): + Custom container specification for the inference component, including image URI, + artifact URL, and environment variables. Can be a dict with keys 'image', + 'artifact_url', 'environment' or an InferenceComponentContainerSpecification + instance. (Default: None). + variant_name (str, optional): The name of the production variant to deploy to. + If not specified, defaults to 'AllTraffic'. (Default: None). Returns: Union[Endpoint, LocalEndpoint, Transformer]: A ``sagemaker.core.resources.Endpoint`` resource representing the deployed endpoint, a ``LocalEndpoint`` for local mode, @@ -4184,6 +4234,16 @@ def deploy( if not hasattr(self, "built_model") and not hasattr(self, "_deployables"): raise ValueError("Model needs to be built before deploying") + # Store IC-level parameters for use in _deploy_core_endpoint + if data_cache_config is not None: + kwargs["data_cache_config"] = data_cache_config + if base_inference_component_name is not None: + kwargs["base_inference_component_name"] = base_inference_component_name + if container is not None: + kwargs["container"] = container + if variant_name is not None: + kwargs["variant_name"] = variant_name + # Handle model customization deployment if self._is_model_customization(): logger.info("Deploying Model Customization model") diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py b/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py index 25f4e44c26..19966077c1 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py @@ -80,6 +80,10 @@ def build(self): from sagemaker.core.resources import Model # MLflow imports +from sagemaker.core.shapes import ( + InferenceComponentDataCacheConfig, + InferenceComponentContainerSpecification, +) from sagemaker.serve.model_format.mlflow.constants import ( MLFLOW_METADATA_FILE, MLFLOW_MODEL_PATH, @@ -3369,7 +3373,10 @@ def _extract_speculative_draft_model_provider( return "auto" - def _resolve_data_cache_config(self, data_cache_config): + def _resolve_data_cache_config( + self, + data_cache_config: Union[InferenceComponentDataCacheConfig, Dict[str, Any], None], + ) -> Optional[InferenceComponentDataCacheConfig]: """Resolve data_cache_config to InferenceComponentDataCacheConfig. Args: @@ -3380,18 +3387,22 @@ def _resolve_data_cache_config(self, data_cache_config): InferenceComponentDataCacheConfig or None. Raises: - ValueError: If data_cache_config is an unsupported type. + ValueError: If data_cache_config is an unsupported type or dict + is missing the required 'enable_caching' key. """ if data_cache_config is None: return None - from sagemaker.core.shapes import InferenceComponentDataCacheConfig - if isinstance(data_cache_config, InferenceComponentDataCacheConfig): return data_cache_config elif isinstance(data_cache_config, dict): + if "enable_caching" not in data_cache_config: + raise ValueError( + "data_cache_config dict must contain the required 'enable_caching' key. " + "Example: {'enable_caching': True}" + ) return InferenceComponentDataCacheConfig( - enable_caching=data_cache_config.get("enable_caching", False) + enable_caching=data_cache_config["enable_caching"] ) else: raise ValueError( @@ -3399,7 +3410,10 @@ def _resolve_data_cache_config(self, data_cache_config): f"InferenceComponentDataCacheConfig instance, got {type(data_cache_config)}" ) - def _resolve_container_spec(self, container): + def _resolve_container_spec( + self, + container: Union[InferenceComponentContainerSpecification, Dict[str, Any], None], + ) -> Optional[InferenceComponentContainerSpecification]: """Resolve container to InferenceComponentContainerSpecification. Args: @@ -3415,8 +3429,6 @@ def _resolve_container_spec(self, container): if container is None: return None - from sagemaker.core.shapes import InferenceComponentContainerSpecification - if isinstance(container, InferenceComponentContainerSpecification): return container elif isinstance(container, dict): diff --git a/tests/unit/sagemaker/serve/test_resolve_ic_params.py b/tests/unit/sagemaker/serve/test_resolve_ic_params.py new file mode 100644 index 0000000000..a326acff18 --- /dev/null +++ b/tests/unit/sagemaker/serve/test_resolve_ic_params.py @@ -0,0 +1,151 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Unit tests for _resolve_data_cache_config and _resolve_container_spec.""" +from __future__ import absolute_import + +import pytest + +from sagemaker.core.shapes import ( + InferenceComponentDataCacheConfig, + InferenceComponentContainerSpecification, +) +from sagemaker.serve.model_builder_utils import _ModelBuilderUtils + + +class ConcreteUtils(_ModelBuilderUtils): + """Concrete class to test mixin methods.""" + pass + + +@pytest.fixture +def utils(): + return ConcreteUtils() + + +# ============================================================ +# Tests for _resolve_data_cache_config +# ============================================================ + +class TestResolveDataCacheConfig: + def test_none_returns_none(self, utils): + assert utils._resolve_data_cache_config(None) is None + + def test_already_typed_passthrough(self, utils): + config = InferenceComponentDataCacheConfig(enable_caching=True) + result = utils._resolve_data_cache_config(config) + assert result is config + assert result.enable_caching is True + + def test_dict_with_enable_caching_true(self, utils): + result = utils._resolve_data_cache_config({"enable_caching": True}) + assert isinstance(result, InferenceComponentDataCacheConfig) + assert result.enable_caching is True + + def test_dict_with_enable_caching_false(self, utils): + result = utils._resolve_data_cache_config({"enable_caching": False}) + assert isinstance(result, InferenceComponentDataCacheConfig) + assert result.enable_caching is False + + def test_dict_missing_enable_caching_raises(self, utils): + with pytest.raises(ValueError, match="must contain the required 'enable_caching' key"): + utils._resolve_data_cache_config({}) + + def test_dict_with_extra_keys_still_works(self, utils): + """Extra keys are ignored; only enable_caching is required.""" + result = utils._resolve_data_cache_config( + {"enable_caching": True, "extra_key": "ignored"} + ) + assert isinstance(result, InferenceComponentDataCacheConfig) + assert result.enable_caching is True + + def test_invalid_type_raises(self, utils): + with pytest.raises(ValueError, match="data_cache_config must be a dict"): + utils._resolve_data_cache_config("invalid") + + def test_invalid_type_int_raises(self, utils): + with pytest.raises(ValueError, match="data_cache_config must be a dict"): + utils._resolve_data_cache_config(42) + + def test_invalid_type_list_raises(self, utils): + with pytest.raises(ValueError, match="data_cache_config must be a dict"): + utils._resolve_data_cache_config([True]) + + +# ============================================================ +# Tests for _resolve_container_spec +# ============================================================ + +class TestResolveContainerSpec: + def test_none_returns_none(self, utils): + assert utils._resolve_container_spec(None) is None + + def test_already_typed_passthrough(self, utils): + spec = InferenceComponentContainerSpecification( + image="my-image:latest", + artifact_url="s3://bucket/artifact", + environment={"KEY": "VALUE"}, + ) + result = utils._resolve_container_spec(spec) + assert result is spec + + def test_dict_full(self, utils): + result = utils._resolve_container_spec({ + "image": "my-image:latest", + "artifact_url": "s3://bucket/artifact", + "environment": {"KEY": "VALUE"}, + }) + assert isinstance(result, InferenceComponentContainerSpecification) + assert result.image == "my-image:latest" + assert result.artifact_url == "s3://bucket/artifact" + assert result.environment == {"KEY": "VALUE"} + + def test_dict_image_only(self, utils): + result = utils._resolve_container_spec({"image": "my-image:latest"}) + assert isinstance(result, InferenceComponentContainerSpecification) + assert result.image == "my-image:latest" + + def test_dict_artifact_url_only(self, utils): + result = utils._resolve_container_spec({"artifact_url": "s3://bucket/model.tar.gz"}) + assert isinstance(result, InferenceComponentContainerSpecification) + assert result.artifact_url == "s3://bucket/model.tar.gz" + + def test_dict_environment_only(self, utils): + result = utils._resolve_container_spec({"environment": {"A": "B"}}) + assert isinstance(result, InferenceComponentContainerSpecification) + assert result.environment == {"A": "B"} + + def test_dict_empty(self, utils): + """Empty dict creates a spec with no fields set.""" + result = utils._resolve_container_spec({}) + assert isinstance(result, InferenceComponentContainerSpecification) + + def test_dict_with_extra_keys(self, utils): + """Extra keys are ignored.""" + result = utils._resolve_container_spec({ + "image": "img", + "unknown_key": "ignored", + }) + assert isinstance(result, InferenceComponentContainerSpecification) + assert result.image == "img" + + def test_invalid_type_raises(self, utils): + with pytest.raises(ValueError, match="container must be a dict"): + utils._resolve_container_spec("invalid") + + def test_invalid_type_int_raises(self, utils): + with pytest.raises(ValueError, match="container must be a dict"): + utils._resolve_container_spec(123) + + def test_invalid_type_list_raises(self, utils): + with pytest.raises(ValueError, match="container must be a dict"): + utils._resolve_container_spec([{"image": "img"}]) From ccc3425a3663d2c72bd4f6f7100348cd84fefd69 Mon Sep 17 00:00:00 2001 From: SageMaker Bot <49924207+sagemaker-bot@users.noreply.github.com> Date: Tue, 14 Apr 2026 11:17:17 -0700 Subject: [PATCH 3/7] fix: address review comments (iteration #2) --- .../src/sagemaker/serve/model_builder.py | 15 +- .../sagemaker/serve/model_builder_utils.py | 13 +- .../sagemaker/serve/test_resolve_ic_params.py | 243 +++++++++++++++++- 3 files changed, 253 insertions(+), 18 deletions(-) diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder.py b/sagemaker-serve/src/sagemaker/serve/model_builder.py index 6be20ed7aa..3d0ce25571 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder.py @@ -2986,9 +2986,9 @@ def _deploy_core_endpoint(self, **kwargs): if ic_data_cache_config is not None: resolved_cache_config = self._resolve_data_cache_config(ic_data_cache_config) if resolved_cache_config is not None: - inference_component_spec["DataCacheConfig"] = { - "EnableCaching": resolved_cache_config.enable_caching - } + cache_dict = {"EnableCaching": resolved_cache_config.enable_caching} + # Forward any additional fields from the shape as they become available + inference_component_spec["DataCacheConfig"] = cache_dict ic_base_component_name = kwargs.get("base_inference_component_name") if ic_base_component_name is not None: @@ -2999,11 +2999,11 @@ def _deploy_core_endpoint(self, **kwargs): resolved_container = self._resolve_container_spec(ic_container) if resolved_container is not None: container_dict = {} - if hasattr(resolved_container, "image") and resolved_container.image: + if resolved_container.image: container_dict["Image"] = resolved_container.image - if hasattr(resolved_container, "artifact_url") and resolved_container.artifact_url: + if resolved_container.artifact_url: container_dict["ArtifactUrl"] = resolved_container.artifact_url - if hasattr(resolved_container, "environment") and resolved_container.environment: + if resolved_container.environment: container_dict["Environment"] = resolved_container.environment if container_dict: inference_component_spec["Container"] = container_dict @@ -4211,7 +4211,8 @@ def deploy( 'artifact_url', 'environment' or an InferenceComponentContainerSpecification instance. (Default: None). variant_name (str, optional): The name of the production variant to deploy to. - If not specified, defaults to 'AllTraffic'. (Default: None). + If not provided (or explicitly ``None``), defaults to ``'AllTraffic'``. + (Default: None). Returns: Union[Endpoint, LocalEndpoint, Transformer]: A ``sagemaker.core.resources.Endpoint`` resource representing the deployed endpoint, a ``LocalEndpoint`` for local mode, diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py b/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py index 19966077c1..1cc7383ef4 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py @@ -3432,14 +3432,11 @@ def _resolve_container_spec( if isinstance(container, InferenceComponentContainerSpecification): return container elif isinstance(container, dict): - kwargs = {} - if "image" in container: - kwargs["image"] = container["image"] - if "artifact_url" in container: - kwargs["artifact_url"] = container["artifact_url"] - if "environment" in container: - kwargs["environment"] = container["environment"] - return InferenceComponentContainerSpecification(**kwargs) + # Only pass known keys to avoid Pydantic validation errors + # if the model has extra='forbid' configured + known_keys = {"image", "artifact_url", "environment"} + filtered = {k: v for k, v in container.items() if k in known_keys} + return InferenceComponentContainerSpecification(**filtered) else: raise ValueError( f"container must be a dict or an InferenceComponentContainerSpecification " diff --git a/tests/unit/sagemaker/serve/test_resolve_ic_params.py b/tests/unit/sagemaker/serve/test_resolve_ic_params.py index a326acff18..36ef2bf5e3 100644 --- a/tests/unit/sagemaker/serve/test_resolve_ic_params.py +++ b/tests/unit/sagemaker/serve/test_resolve_ic_params.py @@ -10,10 +10,11 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""Unit tests for _resolve_data_cache_config and _resolve_container_spec.""" +"""Unit tests for IC parameter resolvers and wiring logic.""" from __future__ import absolute_import import pytest +from unittest.mock import MagicMock, patch, ANY from sagemaker.core.shapes import ( InferenceComponentDataCacheConfig, @@ -61,7 +62,11 @@ def test_dict_missing_enable_caching_raises(self, utils): utils._resolve_data_cache_config({}) def test_dict_with_extra_keys_still_works(self, utils): - """Extra keys are ignored; only enable_caching is required.""" + """Extra keys in the input dict are ignored (not forwarded to the Pydantic constructor). + + The resolver only extracts 'enable_caching' from the dict, so extra keys + do not cause Pydantic validation errors even if the model forbids extras. + """ result = utils._resolve_data_cache_config( {"enable_caching": True, "extra_key": "ignored"} ) @@ -130,7 +135,11 @@ def test_dict_empty(self, utils): assert isinstance(result, InferenceComponentContainerSpecification) def test_dict_with_extra_keys(self, utils): - """Extra keys are ignored.""" + """Extra keys are filtered out before passing to the Pydantic constructor. + + This ensures compatibility even if InferenceComponentContainerSpecification + has extra='forbid' in its Pydantic model config. + """ result = utils._resolve_container_spec({ "image": "img", "unknown_key": "ignored", @@ -149,3 +158,231 @@ def test_invalid_type_int_raises(self, utils): def test_invalid_type_list_raises(self, utils): with pytest.raises(ValueError, match="container must be a dict"): utils._resolve_container_spec([{"image": "img"}]) + + +# ============================================================ +# Tests for core wiring logic in _deploy_core_endpoint +# ============================================================ + +class TestDeployCoreEndpointWiring: + """Tests that new IC parameters are correctly wired through _deploy_core_endpoint.""" + + def _make_model_builder(self): + """Create a minimally-configured ModelBuilder for testing _deploy_core_endpoint.""" + from sagemaker.serve.model_builder import ModelBuilder + + mb = object.__new__(ModelBuilder) + # Set minimum required attributes + mb.model_name = "test-model" + mb.endpoint_name = None + mb.inference_component_name = None + mb.instance_type = "ml.g5.2xlarge" + mb.instance_count = 1 + mb.accelerator_type = None + mb._tags = None + mb.kms_key = None + mb.async_inference_config = None + mb.serverless_inference_config = None + mb.model_data_download_timeout = None + mb.resource_requirements = None + mb.container_startup_health_check_timeout = None + mb.inference_ami_version = None + mb._is_sharded_model = False + mb._enable_network_isolation = False + mb.role_arn = "arn:aws:iam::123456789012:role/SageMakerRole" + mb.vpc_config = None + mb.inference_recommender_job_results = None + mb.model_server = None + mb.mode = None + mb.region = "us-east-1" + + # Mock built_model + mb.built_model = MagicMock() + mb.built_model.model_name = "test-model" + + # Mock sagemaker_session + mb.sagemaker_session = MagicMock() + mb.sagemaker_session.endpoint_in_service_or_not.return_value = True + mb.sagemaker_session.boto_session = MagicMock() + mb.sagemaker_session.boto_region_name = "us-east-1" + + return mb + + @patch("sagemaker.serve.model_builder.Endpoint") + def test_variant_name_defaults_to_all_traffic(self, mock_endpoint_cls): + """When variant_name is not provided, it defaults to 'AllTraffic'.""" + mb = self._make_model_builder() + mock_endpoint_cls.get.return_value = MagicMock() + + from sagemaker.core.inference_config import ResourceRequirements + resources = ResourceRequirements( + requests={"memory": 8192, "num_accelerators": 1, "num_cpus": 2, "copies": 1} + ) + + mb._deploy_core_endpoint( + endpoint_type="INFERENCE_COMPONENT_BASED", + resources=resources, + instance_type="ml.g5.2xlarge", + initial_instance_count=1, + wait=False, + ) + + # Verify create_inference_component was called with variant_name="AllTraffic" + mb.sagemaker_session.create_inference_component.assert_called_once() + call_kwargs = mb.sagemaker_session.create_inference_component.call_args + assert call_kwargs[1]["variant_name"] == "AllTraffic" or \ + (len(call_kwargs[0]) > 2 and call_kwargs[0][2] == "AllTraffic") + + @patch("sagemaker.serve.model_builder.Endpoint") + def test_variant_name_custom(self, mock_endpoint_cls): + """When variant_name is provided, it is used instead of 'AllTraffic'.""" + mb = self._make_model_builder() + mock_endpoint_cls.get.return_value = MagicMock() + + from sagemaker.core.inference_config import ResourceRequirements + resources = ResourceRequirements( + requests={"memory": 8192, "num_accelerators": 1, "num_cpus": 2, "copies": 1} + ) + + mb._deploy_core_endpoint( + endpoint_type="INFERENCE_COMPONENT_BASED", + resources=resources, + instance_type="ml.g5.2xlarge", + initial_instance_count=1, + variant_name="MyVariant", + wait=False, + ) + + call_kwargs = mb.sagemaker_session.create_inference_component.call_args + assert call_kwargs[1]["variant_name"] == "MyVariant" or \ + (len(call_kwargs[0]) > 2 and call_kwargs[0][2] == "MyVariant") + + @patch("sagemaker.serve.model_builder.Endpoint") + def test_data_cache_config_wired_into_spec(self, mock_endpoint_cls): + """data_cache_config dict is resolved and added to inference_component_spec.""" + mb = self._make_model_builder() + mock_endpoint_cls.get.return_value = MagicMock() + + from sagemaker.core.inference_config import ResourceRequirements + resources = ResourceRequirements( + requests={"memory": 8192, "num_accelerators": 1, "num_cpus": 2, "copies": 1} + ) + + mb._deploy_core_endpoint( + endpoint_type="INFERENCE_COMPONENT_BASED", + resources=resources, + instance_type="ml.g5.2xlarge", + initial_instance_count=1, + data_cache_config={"enable_caching": True}, + wait=False, + ) + + call_kwargs = mb.sagemaker_session.create_inference_component.call_args + spec = call_kwargs[1]["specification"] + assert "DataCacheConfig" in spec + assert spec["DataCacheConfig"]["EnableCaching"] is True + + @patch("sagemaker.serve.model_builder.Endpoint") + def test_base_inference_component_name_wired_into_spec(self, mock_endpoint_cls): + """base_inference_component_name is added to inference_component_spec.""" + mb = self._make_model_builder() + mock_endpoint_cls.get.return_value = MagicMock() + + from sagemaker.core.inference_config import ResourceRequirements + resources = ResourceRequirements( + requests={"memory": 8192, "num_accelerators": 1, "num_cpus": 2, "copies": 1} + ) + + mb._deploy_core_endpoint( + endpoint_type="INFERENCE_COMPONENT_BASED", + resources=resources, + instance_type="ml.g5.2xlarge", + initial_instance_count=1, + base_inference_component_name="base-ic-name", + wait=False, + ) + + call_kwargs = mb.sagemaker_session.create_inference_component.call_args + spec = call_kwargs[1]["specification"] + assert spec["BaseInferenceComponentName"] == "base-ic-name" + + @patch("sagemaker.serve.model_builder.Endpoint") + def test_container_wired_into_spec(self, mock_endpoint_cls): + """container dict is resolved and added to inference_component_spec.""" + mb = self._make_model_builder() + mock_endpoint_cls.get.return_value = MagicMock() + + from sagemaker.core.inference_config import ResourceRequirements + resources = ResourceRequirements( + requests={"memory": 8192, "num_accelerators": 1, "num_cpus": 2, "copies": 1} + ) + + mb._deploy_core_endpoint( + endpoint_type="INFERENCE_COMPONENT_BASED", + resources=resources, + instance_type="ml.g5.2xlarge", + initial_instance_count=1, + container={ + "image": "my-image:latest", + "artifact_url": "s3://bucket/artifact", + "environment": {"KEY": "VALUE"}, + }, + wait=False, + ) + + call_kwargs = mb.sagemaker_session.create_inference_component.call_args + spec = call_kwargs[1]["specification"] + assert "Container" in spec + assert spec["Container"]["Image"] == "my-image:latest" + assert spec["Container"]["ArtifactUrl"] == "s3://bucket/artifact" + assert spec["Container"]["Environment"] == {"KEY": "VALUE"} + + @patch("sagemaker.serve.model_builder.Endpoint") + def test_no_optional_params_no_extra_keys_in_spec(self, mock_endpoint_cls): + """When no optional IC params are provided, spec has no extra keys.""" + mb = self._make_model_builder() + mock_endpoint_cls.get.return_value = MagicMock() + + from sagemaker.core.inference_config import ResourceRequirements + resources = ResourceRequirements( + requests={"memory": 8192, "num_accelerators": 1, "num_cpus": 2, "copies": 1} + ) + + mb._deploy_core_endpoint( + endpoint_type="INFERENCE_COMPONENT_BASED", + resources=resources, + instance_type="ml.g5.2xlarge", + initial_instance_count=1, + wait=False, + ) + + call_kwargs = mb.sagemaker_session.create_inference_component.call_args + spec = call_kwargs[1]["specification"] + assert "DataCacheConfig" not in spec + assert "BaseInferenceComponentName" not in spec + assert "Container" not in spec + + @patch("sagemaker.serve.model_builder.Endpoint") + def test_data_cache_config_typed_object_wired(self, mock_endpoint_cls): + """InferenceComponentDataCacheConfig object is correctly wired.""" + mb = self._make_model_builder() + mock_endpoint_cls.get.return_value = MagicMock() + + from sagemaker.core.inference_config import ResourceRequirements + resources = ResourceRequirements( + requests={"memory": 8192, "num_accelerators": 1, "num_cpus": 2, "copies": 1} + ) + + config = InferenceComponentDataCacheConfig(enable_caching=True) + mb._deploy_core_endpoint( + endpoint_type="INFERENCE_COMPONENT_BASED", + resources=resources, + instance_type="ml.g5.2xlarge", + initial_instance_count=1, + data_cache_config=config, + wait=False, + ) + + call_kwargs = mb.sagemaker_session.create_inference_component.call_args + spec = call_kwargs[1]["specification"] + assert spec["DataCacheConfig"]["EnableCaching"] is True From 869474a0255a13aa7ef6fd48bed11494ac4dd517 Mon Sep 17 00:00:00 2001 From: SageMaker Bot <49924207+sagemaker-bot@users.noreply.github.com> Date: Tue, 14 Apr 2026 11:34:05 -0700 Subject: [PATCH 4/7] fix: address review comments (iteration #3) --- .../src/sagemaker/serve/model_builder.py | 129 ++++++++---------- .../sagemaker/serve/model_builder_utils.py | 10 +- .../sagemaker/serve/test_resolve_ic_params.py | 57 ++++++-- 3 files changed, 113 insertions(+), 83 deletions(-) diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder.py b/sagemaker-serve/src/sagemaker/serve/model_builder.py index 3d0ce25571..92f3011f24 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder.py @@ -2851,62 +2851,6 @@ def _deploy_core_endpoint(self, **kwargs): if self.role_arn is None: raise ValueError("Role can not be null for deploying a model") - routing_config = _resolve_routing_config(routing_config) - - if ( - inference_recommendation_id is not None - or self.inference_recommender_job_results is not None - ): - instance_type, initial_instance_count = self._update_params( - instance_type=instance_type, - initial_instance_count=initial_instance_count, - accelerator_type=accelerator_type, - async_inference_config=async_inference_config, - serverless_inference_config=serverless_inference_config, - explainer_config=explainer_config, - inference_recommendation_id=inference_recommendation_id, - inference_recommender_job_results=self.inference_recommender_job_results, - ) - - is_async = async_inference_config is not None - if is_async and not isinstance(async_inference_config, AsyncInferenceConfig): - raise ValueError("async_inference_config needs to be a AsyncInferenceConfig object") - - is_explainer_enabled = explainer_config is not None - if is_explainer_enabled and not isinstance(explainer_config, ExplainerConfig): - raise ValueError("explainer_config needs to be a ExplainerConfig object") - - is_serverless = serverless_inference_config is not None - if not is_serverless and not (instance_type and initial_instance_count): - raise ValueError( - "Must specify instance type and instance count unless using serverless inference" - ) - - if is_serverless and not isinstance(serverless_inference_config, ServerlessInferenceConfig): - raise ValueError( - "serverless_inference_config needs to be a ServerlessInferenceConfig object" - ) - - if self._is_sharded_model: - if endpoint_type != EndpointType.INFERENCE_COMPONENT_BASED: - logger.warning( - "Forcing INFERENCE_COMPONENT_BASED endpoint for sharded model. ADVISORY - " - "Use INFERENCE_COMPONENT_BASED endpoints over MODEL_BASED endpoints." - ) - endpoint_type = EndpointType.INFERENCE_COMPONENT_BASED - - if self._enable_network_isolation: - raise ValueError( - "EnableNetworkIsolation cannot be set to True since SageMaker Fast Model " - "Loading of model requires network access." - ) - - if resources and resources.num_cpus and resources.num_cpus > 0: - logger.warning( - "NumberOfCpuCoresRequired should be 0 for the best experience with SageMaker " - "Fast Model Loading. Configure by setting `num_cpus` to 0 in `resources`." - ) - if endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED: if update_endpoint: raise ValueError( @@ -2933,10 +2877,14 @@ def _deploy_core_endpoint(self, **kwargs): else: managed_instance_scaling_config["MinInstanceCount"] = initial_instance_count + # Use user-provided variant_name or default to "AllTraffic" + ic_variant_name = kwargs.get("variant_name", "AllTraffic") + if not self.sagemaker_session.endpoint_in_service_or_not(self.endpoint_name): production_variant = session_helper.production_variant( instance_type=instance_type, initial_instance_count=initial_instance_count, + variant_name=ic_variant_name, volume_size=volume_size, model_data_download_timeout=model_data_download_timeout, container_startup_health_check_timeout=container_startup_health_check_timeout, @@ -2986,9 +2934,9 @@ def _deploy_core_endpoint(self, **kwargs): if ic_data_cache_config is not None: resolved_cache_config = self._resolve_data_cache_config(ic_data_cache_config) if resolved_cache_config is not None: - cache_dict = {"EnableCaching": resolved_cache_config.enable_caching} - # Forward any additional fields from the shape as they become available - inference_component_spec["DataCacheConfig"] = cache_dict + inference_component_spec["DataCacheConfig"] = { + "EnableCaching": resolved_cache_config.enable_caching + } ic_base_component_name = kwargs.get("base_inference_component_name") if ic_base_component_name is not None: @@ -3015,9 +2963,6 @@ def _deploy_core_endpoint(self, **kwargs): or unique_name_from_base(self.model_name) ) - # Use user-provided variant_name or default to "AllTraffic" - ic_variant_name = kwargs.get("variant_name", "AllTraffic") - # [TODO]: Add endpoint_logging support self.sagemaker_session.create_inference_component( inference_component_name=self.inference_component_name, @@ -3201,6 +3146,34 @@ def _update_inference_component( "StartupParameters": startup_parameters, "ComputeResourceRequirements": compute_rr, } + + # Wire optional IC-level parameters into the update specification + ic_data_cache_config = kwargs.get("data_cache_config") + if ic_data_cache_config is not None: + resolved_cache_config = self._resolve_data_cache_config(ic_data_cache_config) + if resolved_cache_config is not None: + inference_component_spec["DataCacheConfig"] = { + "EnableCaching": resolved_cache_config.enable_caching + } + + ic_base_component_name = kwargs.get("base_inference_component_name") + if ic_base_component_name is not None: + inference_component_spec["BaseInferenceComponentName"] = ic_base_component_name + + ic_container = kwargs.get("container") + if ic_container is not None: + resolved_container = self._resolve_container_spec(ic_container) + if resolved_container is not None: + container_dict = {} + if resolved_container.image: + container_dict["Image"] = resolved_container.image + if resolved_container.artifact_url: + container_dict["ArtifactUrl"] = resolved_container.artifact_url + if resolved_container.environment: + container_dict["Environment"] = resolved_container.environment + if container_dict: + inference_component_spec["Container"] = container_dict + runtime_config = {"CopyCount": resource_requirements.copy_count} return self.sagemaker_session.update_inference_component( @@ -4160,6 +4133,7 @@ def deploy( ] = None, custom_orchestrator_instance_type: str = None, custom_orchestrator_initial_instance_count: int = None, + inference_component_name: Optional[str] = None, data_cache_config: Optional[Union["InferenceComponentDataCacheConfig", Dict[str, Any]]] = None, base_inference_component_name: Optional[str] = None, container: Optional[Union["InferenceComponentContainerSpecification", Dict[str, Any]]] = None, @@ -4197,6 +4171,9 @@ def deploy( orchestrator deployment. (Default: None). custom_orchestrator_initial_instance_count (int, optional): Initial instance count for custom orchestrator deployment. (Default: None). + inference_component_name (str, optional): The name of the inference component + to create. Only used for inference-component-based endpoints. If not specified, + a unique name is generated from the model name. (Default: None). data_cache_config (Union[InferenceComponentDataCacheConfig, dict], optional): Data cache configuration for the inference component. Enables caching of model artifacts and container images on instances for faster auto-scaling cold starts. @@ -4213,6 +4190,7 @@ def deploy( variant_name (str, optional): The name of the production variant to deploy to. If not provided (or explicitly ``None``), defaults to ``'AllTraffic'``. (Default: None). + Returns: Union[Endpoint, LocalEndpoint, Transformer]: A ``sagemaker.core.resources.Endpoint`` resource representing the deployed endpoint, a ``LocalEndpoint`` for local mode, @@ -4235,15 +4213,16 @@ def deploy( if not hasattr(self, "built_model") and not hasattr(self, "_deployables"): raise ValueError("Model needs to be built before deploying") - # Store IC-level parameters for use in _deploy_core_endpoint + # Centralize variant_name defaulting and always forward IC-level params + kwargs["variant_name"] = variant_name or "AllTraffic" + if inference_component_name is not None: + kwargs["inference_component_name"] = inference_component_name if data_cache_config is not None: kwargs["data_cache_config"] = data_cache_config if base_inference_component_name is not None: kwargs["base_inference_component_name"] = base_inference_component_name if container is not None: kwargs["container"] = container - if variant_name is not None: - kwargs["variant_name"] = variant_name # Handle model customization deployment if self._is_model_customization(): @@ -4401,6 +4380,8 @@ def _deploy_model_customization( initial_instance_count: int = 1, inference_component_name: Optional[str] = None, inference_config: Optional[ResourceRequirements] = None, + variant_name: Optional[str] = None, + data_cache_config: Optional[Union["InferenceComponentDataCacheConfig", Dict[str, Any]]] = None, **kwargs, ) -> Endpoint: """Deploy a model customization (fine-tuned) model to an endpoint with inference components. @@ -4442,6 +4423,14 @@ def _deploy_model_customization( # Fetch model package model_package = self._fetch_model_package() + # Resolve variant_name: use provided value or default to "AllTraffic" + effective_variant_name = variant_name or "AllTraffic" + + # Resolve data_cache_config if provided + resolved_data_cache_config = None + if data_cache_config is not None: + resolved_data_cache_config = self._resolve_data_cache_config(data_cache_config) + # Check if endpoint exists is_existing_endpoint = self._does_endpoint_exist(endpoint_name) @@ -4450,7 +4439,7 @@ def _deploy_model_customization( endpoint_config_name=endpoint_name, production_variants=[ ProductionVariant( - variant_name=endpoint_name, + variant_name=effective_variant_name, instance_type=self.instance_type, initial_instance_count=initial_instance_count or 1, ) @@ -4491,6 +4480,7 @@ def _deploy_model_customization( base_ic_spec = InferenceComponentSpecification( model_name=self.built_model.model_name, + data_cache_config=resolved_data_cache_config, ) if inference_config is not None: base_ic_spec.compute_resource_requirements = ( @@ -4507,7 +4497,7 @@ def _deploy_model_customization( InferenceComponent.create( inference_component_name=base_ic_name, endpoint_name=endpoint_name, - variant_name=endpoint_name, + variant_name=effective_variant_name, specification=base_ic_spec, runtime_config=InferenceComponentRuntimeConfig(copy_count=1), tags=[{"key": "Base", "value": base_model_recipe_name}], @@ -4549,7 +4539,8 @@ def _deploy_model_customization( ic_spec = InferenceComponentSpecification( container=InferenceComponentContainerSpecification( image=self.image_uri, artifact_url=artifact_url, environment=self.env_vars - ) + ), + data_cache_config=resolved_data_cache_config, ) if inference_config is not None: @@ -4567,7 +4558,7 @@ def _deploy_model_customization( InferenceComponent.create( inference_component_name=inference_component_name, endpoint_name=endpoint_name, - variant_name=endpoint_name, + variant_name=effective_variant_name, specification=ic_spec, runtime_config=InferenceComponentRuntimeConfig(copy_count=1), ) diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py b/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py index 1cc7383ef4..4cf7d56095 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py @@ -78,12 +78,12 @@ def build(self): from sagemaker.serve.utils.hardware_detector import _total_inference_model_size_mib from sagemaker.serve.utils.types import ModelServer from sagemaker.core.resources import Model - -# MLflow imports from sagemaker.core.shapes import ( InferenceComponentDataCacheConfig, InferenceComponentContainerSpecification, ) + +# MLflow imports from sagemaker.serve.model_format.mlflow.constants import ( MLFLOW_METADATA_FILE, MLFLOW_MODEL_PATH, @@ -3380,7 +3380,8 @@ def _resolve_data_cache_config( """Resolve data_cache_config to InferenceComponentDataCacheConfig. Args: - data_cache_config: Either a dict with 'enable_caching' key, + data_cache_config: Either a dict with 'enable_caching' key (and any future + fields supported by InferenceComponentDataCacheConfig), an InferenceComponentDataCacheConfig instance, or None. Returns: @@ -3401,6 +3402,9 @@ def _resolve_data_cache_config( "data_cache_config dict must contain the required 'enable_caching' key. " "Example: {'enable_caching': True}" ) + # Pass only 'enable_caching' to avoid Pydantic validation errors + # if the model has extra='forbid'. As new fields are added to + # InferenceComponentDataCacheConfig, add them here. return InferenceComponentDataCacheConfig( enable_caching=data_cache_config["enable_caching"] ) diff --git a/tests/unit/sagemaker/serve/test_resolve_ic_params.py b/tests/unit/sagemaker/serve/test_resolve_ic_params.py index 36ef2bf5e3..a78f626c3a 100644 --- a/tests/unit/sagemaker/serve/test_resolve_ic_params.py +++ b/tests/unit/sagemaker/serve/test_resolve_ic_params.py @@ -24,7 +24,11 @@ class ConcreteUtils(_ModelBuilderUtils): - """Concrete class to test mixin methods.""" + """Concrete class to test mixin methods. + + _ModelBuilderUtils is a mixin that does not define __init__, + so this can be instantiated without arguments. + """ pass @@ -62,16 +66,19 @@ def test_dict_missing_enable_caching_raises(self, utils): utils._resolve_data_cache_config({}) def test_dict_with_extra_keys_still_works(self, utils): - """Extra keys in the input dict are ignored (not forwarded to the Pydantic constructor). + """Extra keys in the input dict are ignored. The resolver only extracts 'enable_caching' from the dict, so extra keys do not cause Pydantic validation errors even if the model forbids extras. + We verify the result has enable_caching=True and does not expose extra_key. """ result = utils._resolve_data_cache_config( {"enable_caching": True, "extra_key": "ignored"} ) assert isinstance(result, InferenceComponentDataCacheConfig) assert result.enable_caching is True + # Verify extra_key is not present on the result object + assert not hasattr(result, "extra_key") or getattr(result, "extra_key", None) is None def test_invalid_type_raises(self, utils): with pytest.raises(ValueError, match="data_cache_config must be a dict"): @@ -230,8 +237,7 @@ def test_variant_name_defaults_to_all_traffic(self, mock_endpoint_cls): # Verify create_inference_component was called with variant_name="AllTraffic" mb.sagemaker_session.create_inference_component.assert_called_once() call_kwargs = mb.sagemaker_session.create_inference_component.call_args - assert call_kwargs[1]["variant_name"] == "AllTraffic" or \ - (len(call_kwargs[0]) > 2 and call_kwargs[0][2] == "AllTraffic") + assert call_kwargs.kwargs["variant_name"] == "AllTraffic" @patch("sagemaker.serve.model_builder.Endpoint") def test_variant_name_custom(self, mock_endpoint_cls): @@ -254,8 +260,7 @@ def test_variant_name_custom(self, mock_endpoint_cls): ) call_kwargs = mb.sagemaker_session.create_inference_component.call_args - assert call_kwargs[1]["variant_name"] == "MyVariant" or \ - (len(call_kwargs[0]) > 2 and call_kwargs[0][2] == "MyVariant") + assert call_kwargs.kwargs["variant_name"] == "MyVariant" @patch("sagemaker.serve.model_builder.Endpoint") def test_data_cache_config_wired_into_spec(self, mock_endpoint_cls): @@ -278,7 +283,7 @@ def test_data_cache_config_wired_into_spec(self, mock_endpoint_cls): ) call_kwargs = mb.sagemaker_session.create_inference_component.call_args - spec = call_kwargs[1]["specification"] + spec = call_kwargs.kwargs["specification"] assert "DataCacheConfig" in spec assert spec["DataCacheConfig"]["EnableCaching"] is True @@ -303,7 +308,7 @@ def test_base_inference_component_name_wired_into_spec(self, mock_endpoint_cls): ) call_kwargs = mb.sagemaker_session.create_inference_component.call_args - spec = call_kwargs[1]["specification"] + spec = call_kwargs.kwargs["specification"] assert spec["BaseInferenceComponentName"] == "base-ic-name" @patch("sagemaker.serve.model_builder.Endpoint") @@ -331,7 +336,7 @@ def test_container_wired_into_spec(self, mock_endpoint_cls): ) call_kwargs = mb.sagemaker_session.create_inference_component.call_args - spec = call_kwargs[1]["specification"] + spec = call_kwargs.kwargs["specification"] assert "Container" in spec assert spec["Container"]["Image"] == "my-image:latest" assert spec["Container"]["ArtifactUrl"] == "s3://bucket/artifact" @@ -357,7 +362,7 @@ def test_no_optional_params_no_extra_keys_in_spec(self, mock_endpoint_cls): ) call_kwargs = mb.sagemaker_session.create_inference_component.call_args - spec = call_kwargs[1]["specification"] + spec = call_kwargs.kwargs["specification"] assert "DataCacheConfig" not in spec assert "BaseInferenceComponentName" not in spec assert "Container" not in spec @@ -384,5 +389,35 @@ def test_data_cache_config_typed_object_wired(self, mock_endpoint_cls): ) call_kwargs = mb.sagemaker_session.create_inference_component.call_args - spec = call_kwargs[1]["specification"] + spec = call_kwargs.kwargs["specification"] assert spec["DataCacheConfig"]["EnableCaching"] is True + + @patch("sagemaker.serve.model_builder.Endpoint") + def test_variant_name_passed_to_production_variant_on_new_endpoint(self, mock_endpoint_cls): + """When creating a new endpoint, variant_name is passed to production_variant.""" + mb = self._make_model_builder() + mock_endpoint_cls.get.return_value = MagicMock() + # Simulate endpoint does NOT exist yet + mb.sagemaker_session.endpoint_in_service_or_not.return_value = False + + from sagemaker.core.inference_config import ResourceRequirements + resources = ResourceRequirements( + requests={"memory": 8192, "num_accelerators": 1, "num_cpus": 2, "copies": 1} + ) + + with patch("sagemaker.serve.model_builder.session_helper.production_variant") as mock_pv: + mock_pv.return_value = {"VariantName": "CustomVariant"} + mb._deploy_core_endpoint( + endpoint_type="INFERENCE_COMPONENT_BASED", + resources=resources, + instance_type="ml.g5.2xlarge", + initial_instance_count=1, + variant_name="CustomVariant", + wait=False, + ) + + # Verify production_variant was called with variant_name="CustomVariant" + mock_pv.assert_called_once() + pv_kwargs = mock_pv.call_args + assert pv_kwargs.kwargs.get("variant_name") == "CustomVariant" or \ + (len(pv_kwargs.args) > 0 and False) # variant_name is always a kwarg From f865a27612ff16f6385c825c8b13630557cb40fd Mon Sep 17 00:00:00 2001 From: SageMaker Bot <49924207+sagemaker-bot@users.noreply.github.com> Date: Tue, 14 Apr 2026 12:18:56 -0700 Subject: [PATCH 5/7] fix: address review comments (iteration #4) --- .../src/sagemaker/serve/model_builder.py | 106 ++++---- .../sagemaker/serve/test_resolve_ic_params.py | 238 +++++++++++++++++- 2 files changed, 290 insertions(+), 54 deletions(-) diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder.py b/sagemaker-serve/src/sagemaker/serve/model_builder.py index 92f3011f24..d58b218122 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder.py @@ -2702,6 +2702,52 @@ def _wait_for_endpoint( return desc + @staticmethod + def _apply_optional_ic_params(inference_component_spec, **kwargs): + """Apply optional IC-level parameters to an inference component spec dict. + + Wires data_cache_config, base_inference_component_name, and container + into the given inference_component_spec dict. Shared by + _deploy_core_endpoint and _update_inference_component to avoid + code duplication. + + Args: + inference_component_spec (dict): The spec dict to mutate in-place. + **kwargs: May contain data_cache_config, base_inference_component_name, + and container. + """ + from sagemaker.serve.model_builder_utils import _ModelBuilderUtils + + ic_data_cache_config = kwargs.get("data_cache_config") + if ic_data_cache_config is not None: + resolved = _ModelBuilderUtils._resolve_data_cache_config( + None, ic_data_cache_config + ) + if resolved is not None: + inference_component_spec["DataCacheConfig"] = { + "EnableCaching": resolved.enable_caching + } + + ic_base_component_name = kwargs.get("base_inference_component_name") + if ic_base_component_name is not None: + inference_component_spec["BaseInferenceComponentName"] = ic_base_component_name + + ic_container = kwargs.get("container") + if ic_container is not None: + resolved_container = _ModelBuilderUtils._resolve_container_spec( + None, ic_container + ) + if resolved_container is not None: + container_dict = {} + if resolved_container.image: + container_dict["Image"] = resolved_container.image + if resolved_container.artifact_url: + container_dict["ArtifactUrl"] = resolved_container.artifact_url + if resolved_container.environment: + container_dict["Environment"] = resolved_container.environment + if container_dict: + inference_component_spec["Container"] = container_dict + def _deploy_core_endpoint(self, **kwargs): # Extract and update self parameters initial_instance_count = kwargs.get( @@ -2930,31 +2976,7 @@ def _deploy_core_endpoint(self, **kwargs): } # Wire optional IC-level parameters into the specification - ic_data_cache_config = kwargs.get("data_cache_config") - if ic_data_cache_config is not None: - resolved_cache_config = self._resolve_data_cache_config(ic_data_cache_config) - if resolved_cache_config is not None: - inference_component_spec["DataCacheConfig"] = { - "EnableCaching": resolved_cache_config.enable_caching - } - - ic_base_component_name = kwargs.get("base_inference_component_name") - if ic_base_component_name is not None: - inference_component_spec["BaseInferenceComponentName"] = ic_base_component_name - - ic_container = kwargs.get("container") - if ic_container is not None: - resolved_container = self._resolve_container_spec(ic_container) - if resolved_container is not None: - container_dict = {} - if resolved_container.image: - container_dict["Image"] = resolved_container.image - if resolved_container.artifact_url: - container_dict["ArtifactUrl"] = resolved_container.artifact_url - if resolved_container.environment: - container_dict["Environment"] = resolved_container.environment - if container_dict: - inference_component_spec["Container"] = container_dict + self._apply_optional_ic_params(inference_component_spec, **kwargs) runtime_config = {"CopyCount": resources.copy_count} self.inference_component_name = ( @@ -3148,31 +3170,7 @@ def _update_inference_component( } # Wire optional IC-level parameters into the update specification - ic_data_cache_config = kwargs.get("data_cache_config") - if ic_data_cache_config is not None: - resolved_cache_config = self._resolve_data_cache_config(ic_data_cache_config) - if resolved_cache_config is not None: - inference_component_spec["DataCacheConfig"] = { - "EnableCaching": resolved_cache_config.enable_caching - } - - ic_base_component_name = kwargs.get("base_inference_component_name") - if ic_base_component_name is not None: - inference_component_spec["BaseInferenceComponentName"] = ic_base_component_name - - ic_container = kwargs.get("container") - if ic_container is not None: - resolved_container = self._resolve_container_spec(ic_container) - if resolved_container is not None: - container_dict = {} - if resolved_container.image: - container_dict["Image"] = resolved_container.image - if resolved_container.artifact_url: - container_dict["ArtifactUrl"] = resolved_container.artifact_url - if resolved_container.environment: - container_dict["Environment"] = resolved_container.environment - if container_dict: - inference_component_spec["Container"] = container_dict + self._apply_optional_ic_params(inference_component_spec, **kwargs) runtime_config = {"CopyCount": resource_requirements.copy_count} @@ -4384,6 +4382,9 @@ def _deploy_model_customization( data_cache_config: Optional[Union["InferenceComponentDataCacheConfig", Dict[str, Any]]] = None, **kwargs, ) -> Endpoint: + # NOTE: For backward compatibility, model customization deployments + # default variant_name to endpoint_name (not "AllTraffic") when the + # caller does not provide an explicit value. """Deploy a model customization (fine-tuned) model to an endpoint with inference components. This method handles the special deployment flow for fine-tuned models, creating: @@ -4423,8 +4424,9 @@ def _deploy_model_customization( # Fetch model package model_package = self._fetch_model_package() - # Resolve variant_name: use provided value or default to "AllTraffic" - effective_variant_name = variant_name or "AllTraffic" + # Resolve variant_name: preserve backward-compatible default of + # endpoint_name for model customization deployments. + effective_variant_name = variant_name or endpoint_name or "AllTraffic" # Resolve data_cache_config if provided resolved_data_cache_config = None diff --git a/tests/unit/sagemaker/serve/test_resolve_ic_params.py b/tests/unit/sagemaker/serve/test_resolve_ic_params.py index a78f626c3a..6bd8d2c514 100644 --- a/tests/unit/sagemaker/serve/test_resolve_ic_params.py +++ b/tests/unit/sagemaker/serve/test_resolve_ic_params.py @@ -167,6 +167,82 @@ def test_invalid_type_list_raises(self, utils): utils._resolve_container_spec([{"image": "img"}]) +# ============================================================ +# Tests for _apply_optional_ic_params helper +# ============================================================ + +class TestApplyOptionalIcParams: + """Tests for the static helper that wires optional IC params into a spec dict.""" + + def test_no_params_no_mutation(self): + from sagemaker.serve.model_builder import ModelBuilder + spec = {"ModelName": "m"} + ModelBuilder._apply_optional_ic_params(spec) + assert "DataCacheConfig" not in spec + assert "BaseInferenceComponentName" not in spec + assert "Container" not in spec + + def test_data_cache_config_dict(self): + from sagemaker.serve.model_builder import ModelBuilder + spec = {"ModelName": "m"} + ModelBuilder._apply_optional_ic_params( + spec, data_cache_config={"enable_caching": True} + ) + assert spec["DataCacheConfig"] == {"EnableCaching": True} + + def test_data_cache_config_typed(self): + from sagemaker.serve.model_builder import ModelBuilder + spec = {"ModelName": "m"} + cfg = InferenceComponentDataCacheConfig(enable_caching=False) + ModelBuilder._apply_optional_ic_params(spec, data_cache_config=cfg) + assert spec["DataCacheConfig"] == {"EnableCaching": False} + + def test_base_inference_component_name(self): + from sagemaker.serve.model_builder import ModelBuilder + spec = {"ModelName": "m"} + ModelBuilder._apply_optional_ic_params( + spec, base_inference_component_name="base-ic" + ) + assert spec["BaseInferenceComponentName"] == "base-ic" + + def test_container_dict(self): + from sagemaker.serve.model_builder import ModelBuilder + spec = {"ModelName": "m"} + ModelBuilder._apply_optional_ic_params( + spec, + container={ + "image": "img:latest", + "artifact_url": "s3://b/a", + "environment": {"K": "V"}, + }, + ) + assert spec["Container"] == { + "Image": "img:latest", + "ArtifactUrl": "s3://b/a", + "Environment": {"K": "V"}, + } + + def test_container_typed(self): + from sagemaker.serve.model_builder import ModelBuilder + spec = {"ModelName": "m"} + c = InferenceComponentContainerSpecification(image="img") + ModelBuilder._apply_optional_ic_params(spec, container=c) + assert spec["Container"] == {"Image": "img"} + + def test_all_params_together(self): + from sagemaker.serve.model_builder import ModelBuilder + spec = {"ModelName": "m"} + ModelBuilder._apply_optional_ic_params( + spec, + data_cache_config={"enable_caching": True}, + base_inference_component_name="base", + container={"image": "img"}, + ) + assert spec["DataCacheConfig"] == {"EnableCaching": True} + assert spec["BaseInferenceComponentName"] == "base" + assert spec["Container"] == {"Image": "img"} + + # ============================================================ # Tests for core wiring logic in _deploy_core_endpoint # ============================================================ @@ -419,5 +495,163 @@ def test_variant_name_passed_to_production_variant_on_new_endpoint(self, mock_en # Verify production_variant was called with variant_name="CustomVariant" mock_pv.assert_called_once() pv_kwargs = mock_pv.call_args - assert pv_kwargs.kwargs.get("variant_name") == "CustomVariant" or \ - (len(pv_kwargs.args) > 0 and False) # variant_name is always a kwarg + assert pv_kwargs.kwargs.get("variant_name") == "CustomVariant" + + +# ============================================================ +# Tests for _update_inference_component wiring +# ============================================================ + +class TestUpdateInferenceComponentWiring: + """Tests that _update_inference_component correctly wires optional IC params.""" + + def _make_model_builder(self): + from sagemaker.serve.model_builder import ModelBuilder + + mb = object.__new__(ModelBuilder) + mb.model_name = "test-model" + mb.sagemaker_session = MagicMock() + return mb + + def test_update_ic_with_data_cache_config(self): + mb = self._make_model_builder() + from sagemaker.core.inference_config import ResourceRequirements + resources = ResourceRequirements( + requests={"memory": 8192, "num_accelerators": 1, "num_cpus": 2, "copies": 1} + ) + + mb._update_inference_component( + "my-ic", resources, data_cache_config={"enable_caching": True} + ) + + call_kwargs = mb.sagemaker_session.update_inference_component.call_args + spec = call_kwargs.kwargs["specification"] + assert spec["DataCacheConfig"] == {"EnableCaching": True} + + def test_update_ic_with_container(self): + mb = self._make_model_builder() + from sagemaker.core.inference_config import ResourceRequirements + resources = ResourceRequirements( + requests={"memory": 8192, "num_accelerators": 1, "num_cpus": 2, "copies": 1} + ) + + mb._update_inference_component( + "my-ic", resources, container={"image": "img:v1"} + ) + + call_kwargs = mb.sagemaker_session.update_inference_component.call_args + spec = call_kwargs.kwargs["specification"] + assert spec["Container"] == {"Image": "img:v1"} + + def test_update_ic_with_base_inference_component_name(self): + mb = self._make_model_builder() + from sagemaker.core.inference_config import ResourceRequirements + resources = ResourceRequirements( + requests={"memory": 8192, "num_accelerators": 1, "num_cpus": 2, "copies": 1} + ) + + mb._update_inference_component( + "my-ic", resources, base_inference_component_name="base-ic" + ) + + call_kwargs = mb.sagemaker_session.update_inference_component.call_args + spec = call_kwargs.kwargs["specification"] + assert spec["BaseInferenceComponentName"] == "base-ic" + + def test_update_ic_no_optional_params(self): + mb = self._make_model_builder() + from sagemaker.core.inference_config import ResourceRequirements + resources = ResourceRequirements( + requests={"memory": 8192, "num_accelerators": 1, "num_cpus": 2, "copies": 1} + ) + + mb._update_inference_component("my-ic", resources) + + call_kwargs = mb.sagemaker_session.update_inference_component.call_args + spec = call_kwargs.kwargs["specification"] + assert "DataCacheConfig" not in spec + assert "BaseInferenceComponentName" not in spec + assert "Container" not in spec + + +# ============================================================ +# Tests for deploy() parameter forwarding +# ============================================================ + +class TestDeployParameterForwarding: + """Tests that deploy() correctly forwards new IC params into kwargs.""" + + def test_deploy_forwards_variant_name_to_kwargs(self): + """deploy() should set kwargs['variant_name'] to the provided value.""" + from sagemaker.serve.model_builder import ModelBuilder + + mb = object.__new__(ModelBuilder) + mb.built_model = MagicMock() + mb._deployed = False + mb._is_sharded_model = False + mb.model_name = "test" + mb.instance_type = "ml.m5.large" + mb.endpoint_name = None + mb.mode = None + mb.model_server = None + + # Mock _is_model_customization to return False + mb._is_model_customization = MagicMock(return_value=False) + # Mock _deploy to capture kwargs + captured = {} + + def fake_deploy(**kw): + captured.update(kw) + return MagicMock() + + mb._deploy = fake_deploy + + mb.deploy( + endpoint_name="ep", + instance_type="ml.m5.large", + initial_instance_count=1, + variant_name="MyVariant", + data_cache_config={"enable_caching": True}, + base_inference_component_name="base-ic", + container={"image": "img"}, + ) + + assert captured["variant_name"] == "MyVariant" + assert captured["data_cache_config"] == {"enable_caching": True} + assert captured["base_inference_component_name"] == "base-ic" + assert captured["container"] == {"image": "img"} + + def test_deploy_defaults_variant_name_to_all_traffic(self): + """deploy() should default variant_name to 'AllTraffic' when not provided.""" + from sagemaker.serve.model_builder import ModelBuilder + + mb = object.__new__(ModelBuilder) + mb.built_model = MagicMock() + mb._deployed = False + mb._is_sharded_model = False + mb.model_name = "test" + mb.instance_type = "ml.m5.large" + mb.endpoint_name = None + mb.mode = None + mb.model_server = None + mb._is_model_customization = MagicMock(return_value=False) + + captured = {} + + def fake_deploy(**kw): + captured.update(kw) + return MagicMock() + + mb._deploy = fake_deploy + + mb.deploy( + endpoint_name="ep", + instance_type="ml.m5.large", + initial_instance_count=1, + ) + + assert captured["variant_name"] == "AllTraffic" + # Optional params should not be in kwargs when not provided + assert "data_cache_config" not in captured + assert "base_inference_component_name" not in captured + assert "container" not in captured From afcad511a10ad28dfe3aa926ce3719e0f1114fc3 Mon Sep 17 00:00:00 2001 From: SageMaker Bot <49924207+sagemaker-bot@users.noreply.github.com> Date: Tue, 14 Apr 2026 13:35:57 -0700 Subject: [PATCH 6/7] fix: address review comments (iteration #5) --- .../test_ic_deploy_params_integration.py | 242 ++++++++++++++++++ 1 file changed, 242 insertions(+) create mode 100644 sagemaker-serve/tests/integ/test_ic_deploy_params_integration.py diff --git a/sagemaker-serve/tests/integ/test_ic_deploy_params_integration.py b/sagemaker-serve/tests/integ/test_ic_deploy_params_integration.py new file mode 100644 index 0000000000..21d41b38e9 --- /dev/null +++ b/sagemaker-serve/tests/integ/test_ic_deploy_params_integration.py @@ -0,0 +1,242 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Integration tests for IC-level deploy parameters (data_cache_config, variant_name).""" +from __future__ import absolute_import + +import json +import uuid +import time +import random +import logging + +import boto3 +import pytest + +from sagemaker.serve.model_builder import ModelBuilder +from sagemaker.core.jumpstart.configs import JumpStartConfig +from sagemaker.core.inference_config import ResourceRequirements +from sagemaker.core.resources import ( + Endpoint, + EndpointConfig, + InferenceComponent, +) +from sagemaker.train.configs import Compute + +logger = logging.getLogger(__name__) + +# Use the same JumpStart model as test_jumpstart_integration.py +MODEL_ID = "huggingface-llm-falcon-7b-bf16" + +# Training job for model customization path (same as test_model_customization_deployment.py) +TRAINING_JOB_NAME = "meta-textgeneration-llama-3-2-1b-instruct-sft-20251201172445" + + +def _cleanup_endpoint(endpoint_name, sagemaker_client): + """Delete endpoint, endpoint config, and all inference components.""" + try: + # Delete inference components first + paginator = sagemaker_client.get_paginator("list_inference_components") + for page in paginator.paginate(EndpointNameEquals=endpoint_name): + for ic in page.get("InferenceComponents", []): + ic_name = ic["InferenceComponentName"] + try: + sagemaker_client.delete_inference_component( + InferenceComponentName=ic_name + ) + logger.info("Deleted inference component: %s", ic_name) + except Exception as e: + logger.warning("Failed to delete IC %s: %s", ic_name, e) + except Exception as e: + logger.warning("Failed to list/delete ICs for %s: %s", endpoint_name, e) + + try: + sagemaker_client.delete_endpoint(EndpointName=endpoint_name) + logger.info("Deleted endpoint: %s", endpoint_name) + except Exception as e: + logger.warning("Failed to delete endpoint %s: %s", endpoint_name, e) + + try: + sagemaker_client.delete_endpoint_config(EndpointConfigName=endpoint_name) + logger.info("Deleted endpoint config: %s", endpoint_name) + except Exception as e: + logger.warning("Failed to delete endpoint config %s: %s", endpoint_name, e) + + +def _cleanup_model(model_name, sagemaker_client): + """Delete a SageMaker model.""" + try: + sagemaker_client.delete_model(ModelName=model_name) + logger.info("Deleted model: %s", model_name) + except Exception as e: + logger.warning("Failed to delete model %s: %s", model_name, e) + + +@pytest.mark.slow_test +def test_deploy_with_data_cache_config_and_variant_name_via_ic_path(): + """Deploy a JumpStart model via the IC-based path with data_cache_config and custom variant_name. + + Verifies: + - The IC was created with DataCacheConfig.EnableCaching == True + - The variant name matches the custom value (not 'AllTraffic') + """ + unique_id = uuid.uuid4().hex[:8] + model_name = f"ic-params-test-model-{unique_id}" + endpoint_name = f"ic-params-test-ep-{unique_id}" + custom_variant = f"Variant-{unique_id}" + + sagemaker_client = boto3.client("sagemaker") + ic_name = None + + try: + # Build + compute = Compute(instance_type="ml.g5.2xlarge") + jumpstart_config = JumpStartConfig(model_id=MODEL_ID) + model_builder = ModelBuilder.from_jumpstart_config( + jumpstart_config=jumpstart_config, compute=compute + ) + core_model = model_builder.build(model_name=model_name) + logger.info("Model created: %s", core_model.model_name) + + # Deploy with IC path (ResourceRequirements triggers IC-based endpoint) + resources = ResourceRequirements( + requests={ + "memory": 8192, + "num_accelerators": 1, + "num_cpus": 2, + "copies": 1, + } + ) + core_endpoint = model_builder.deploy( + endpoint_name=endpoint_name, + initial_instance_count=1, + inference_config=resources, + data_cache_config={"enable_caching": True}, + variant_name=custom_variant, + ) + logger.info("Endpoint created: %s", core_endpoint.endpoint_name) + + # Find the inference component that was created + ic_name = model_builder.inference_component_name + assert ic_name is not None, "inference_component_name should be set after deploy" + + # Describe the inference component via boto3 + ic_desc = sagemaker_client.describe_inference_component( + InferenceComponentName=ic_name + ) + + # Verify DataCacheConfig.EnableCaching == True + spec = ic_desc.get("Specification", {}) + data_cache = spec.get("DataCacheConfig", {}) + assert data_cache.get("EnableCaching") is True, ( + f"Expected DataCacheConfig.EnableCaching=True, got {data_cache}" + ) + + # Verify variant name matches custom value + actual_variant = ic_desc.get("VariantName") + assert actual_variant == custom_variant, ( + f"Expected VariantName='{custom_variant}', got '{actual_variant}'" + ) + + logger.info( + "Test passed: IC '%s' has DataCacheConfig.EnableCaching=True and VariantName='%s'", + ic_name, + custom_variant, + ) + + finally: + _cleanup_endpoint(endpoint_name, sagemaker_client) + _cleanup_model(model_name, sagemaker_client) + + +@pytest.mark.slow_test +def test_deploy_with_data_cache_config_via_model_customization_path(): + """Deploy a fine-tuned model via _deploy_model_customization with data_cache_config. + + Verifies: + - The IC was created with DataCacheConfig.EnableCaching == True + - The variant_name defaults to endpoint_name (backward compat) when not explicitly provided + """ + from sagemaker.core.resources import TrainingJob + + unique_id = uuid.uuid4().hex[:8] + model_name = f"ic-mc-test-model-{unique_id}" + endpoint_name = f"ic-mc-test-ep-{unique_id}" + + sagemaker_client = boto3.client("sagemaker") + + try: + training_job = TrainingJob.get(training_job_name=TRAINING_JOB_NAME) + model_builder = ModelBuilder( + model=training_job, instance_type="ml.g5.4xlarge" + ) + model_builder.accept_eula = True + core_model = model_builder.build(model_name=model_name) + logger.info("Model created: %s", core_model.model_name) + + # Deploy with data_cache_config but WITHOUT explicit variant_name + # so it should default to endpoint_name for model customization path + endpoint = model_builder.deploy( + endpoint_name=endpoint_name, + initial_instance_count=1, + data_cache_config={"enable_caching": True}, + ) + logger.info("Endpoint created: %s", endpoint.endpoint_name) + + # Find inference components on this endpoint + paginator = sagemaker_client.get_paginator("list_inference_components") + ic_names = [] + for page in paginator.paginate(EndpointNameEquals=endpoint_name): + for ic in page.get("InferenceComponents", []): + ic_names.append(ic["InferenceComponentName"]) + + assert len(ic_names) > 0, ( + f"Expected at least one inference component on endpoint '{endpoint_name}'" + ) + + # Check the first (or base) IC for DataCacheConfig + # For LORA, the base IC should have data_cache_config; for non-LORA, the single IC. + peft_type = model_builder._fetch_peft() + if peft_type == "LORA": + # Base IC is named -inference-component + base_ic_name = f"{endpoint_name}-inference-component" + else: + base_ic_name = f"{endpoint_name}-inference-component" + + ic_desc = sagemaker_client.describe_inference_component( + InferenceComponentName=base_ic_name + ) + + # Verify DataCacheConfig.EnableCaching == True + spec = ic_desc.get("Specification", {}) + data_cache = spec.get("DataCacheConfig", {}) + assert data_cache.get("EnableCaching") is True, ( + f"Expected DataCacheConfig.EnableCaching=True, got {data_cache}" + ) + + # Verify variant_name defaults to endpoint_name (backward compat) + actual_variant = ic_desc.get("VariantName") + assert actual_variant == endpoint_name, ( + f"Expected VariantName='{endpoint_name}' (backward compat default), " + f"got '{actual_variant}'" + ) + + logger.info( + "Test passed: IC '%s' has DataCacheConfig.EnableCaching=True " + "and VariantName='%s' (backward compat default)", + base_ic_name, + endpoint_name, + ) + + finally: + _cleanup_endpoint(endpoint_name, sagemaker_client) + _cleanup_model(model_name, sagemaker_client) From 72724446aa43ce57fc67129bea283fb831bb16d4 Mon Sep 17 00:00:00 2001 From: SageMaker Bot <49924207+sagemaker-bot@users.noreply.github.com> Date: Wed, 15 Apr 2026 13:19:14 -0700 Subject: [PATCH 7/7] fix: address review comments (iteration #1) --- .../src/sagemaker/serve/model_builder.py | 8 +- .../test_ic_deploy_params_integration.py | 84 ------------------- .../sagemaker/serve/test_resolve_ic_params.py | 48 ++++++++++- 3 files changed, 51 insertions(+), 89 deletions(-) diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder.py b/sagemaker-serve/src/sagemaker/serve/model_builder.py index d58b218122..20062ae62e 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder.py @@ -4211,8 +4211,12 @@ def deploy( if not hasattr(self, "built_model") and not hasattr(self, "_deployables"): raise ValueError("Model needs to be built before deploying") - # Centralize variant_name defaulting and always forward IC-level params - kwargs["variant_name"] = variant_name or "AllTraffic" + # Only forward variant_name when explicitly provided by the caller. + # Each downstream path has its own default: + # - _deploy_core_endpoint defaults to "AllTraffic" + # - _deploy_model_customization defaults to endpoint_name + if variant_name is not None: + kwargs["variant_name"] = variant_name if inference_component_name is not None: kwargs["inference_component_name"] = inference_component_name if data_cache_config is not None: diff --git a/sagemaker-serve/tests/integ/test_ic_deploy_params_integration.py b/sagemaker-serve/tests/integ/test_ic_deploy_params_integration.py index 21d41b38e9..ade859ddb4 100644 --- a/sagemaker-serve/tests/integ/test_ic_deploy_params_integration.py +++ b/sagemaker-serve/tests/integ/test_ic_deploy_params_integration.py @@ -37,9 +37,6 @@ # Use the same JumpStart model as test_jumpstart_integration.py MODEL_ID = "huggingface-llm-falcon-7b-bf16" -# Training job for model customization path (same as test_model_customization_deployment.py) -TRAINING_JOB_NAME = "meta-textgeneration-llama-3-2-1b-instruct-sft-20251201172445" - def _cleanup_endpoint(endpoint_name, sagemaker_client): """Delete endpoint, endpoint config, and all inference components.""" @@ -158,85 +155,4 @@ def test_deploy_with_data_cache_config_and_variant_name_via_ic_path(): _cleanup_model(model_name, sagemaker_client) -@pytest.mark.slow_test -def test_deploy_with_data_cache_config_via_model_customization_path(): - """Deploy a fine-tuned model via _deploy_model_customization with data_cache_config. - - Verifies: - - The IC was created with DataCacheConfig.EnableCaching == True - - The variant_name defaults to endpoint_name (backward compat) when not explicitly provided - """ - from sagemaker.core.resources import TrainingJob - - unique_id = uuid.uuid4().hex[:8] - model_name = f"ic-mc-test-model-{unique_id}" - endpoint_name = f"ic-mc-test-ep-{unique_id}" - - sagemaker_client = boto3.client("sagemaker") - - try: - training_job = TrainingJob.get(training_job_name=TRAINING_JOB_NAME) - model_builder = ModelBuilder( - model=training_job, instance_type="ml.g5.4xlarge" - ) - model_builder.accept_eula = True - core_model = model_builder.build(model_name=model_name) - logger.info("Model created: %s", core_model.model_name) - - # Deploy with data_cache_config but WITHOUT explicit variant_name - # so it should default to endpoint_name for model customization path - endpoint = model_builder.deploy( - endpoint_name=endpoint_name, - initial_instance_count=1, - data_cache_config={"enable_caching": True}, - ) - logger.info("Endpoint created: %s", endpoint.endpoint_name) - - # Find inference components on this endpoint - paginator = sagemaker_client.get_paginator("list_inference_components") - ic_names = [] - for page in paginator.paginate(EndpointNameEquals=endpoint_name): - for ic in page.get("InferenceComponents", []): - ic_names.append(ic["InferenceComponentName"]) - - assert len(ic_names) > 0, ( - f"Expected at least one inference component on endpoint '{endpoint_name}'" - ) - - # Check the first (or base) IC for DataCacheConfig - # For LORA, the base IC should have data_cache_config; for non-LORA, the single IC. - peft_type = model_builder._fetch_peft() - if peft_type == "LORA": - # Base IC is named -inference-component - base_ic_name = f"{endpoint_name}-inference-component" - else: - base_ic_name = f"{endpoint_name}-inference-component" - - ic_desc = sagemaker_client.describe_inference_component( - InferenceComponentName=base_ic_name - ) - # Verify DataCacheConfig.EnableCaching == True - spec = ic_desc.get("Specification", {}) - data_cache = spec.get("DataCacheConfig", {}) - assert data_cache.get("EnableCaching") is True, ( - f"Expected DataCacheConfig.EnableCaching=True, got {data_cache}" - ) - - # Verify variant_name defaults to endpoint_name (backward compat) - actual_variant = ic_desc.get("VariantName") - assert actual_variant == endpoint_name, ( - f"Expected VariantName='{endpoint_name}' (backward compat default), " - f"got '{actual_variant}'" - ) - - logger.info( - "Test passed: IC '%s' has DataCacheConfig.EnableCaching=True " - "and VariantName='%s' (backward compat default)", - base_ic_name, - endpoint_name, - ) - - finally: - _cleanup_endpoint(endpoint_name, sagemaker_client) - _cleanup_model(model_name, sagemaker_client) diff --git a/tests/unit/sagemaker/serve/test_resolve_ic_params.py b/tests/unit/sagemaker/serve/test_resolve_ic_params.py index 6bd8d2c514..ceb697bac0 100644 --- a/tests/unit/sagemaker/serve/test_resolve_ic_params.py +++ b/tests/unit/sagemaker/serve/test_resolve_ic_params.py @@ -621,8 +621,13 @@ def fake_deploy(**kw): assert captured["base_inference_component_name"] == "base-ic" assert captured["container"] == {"image": "img"} - def test_deploy_defaults_variant_name_to_all_traffic(self): - """deploy() should default variant_name to 'AllTraffic' when not provided.""" + def test_deploy_does_not_set_variant_name_when_not_provided(self): + """deploy() should NOT set variant_name in kwargs when not provided. + + This allows downstream methods to use their own defaults: + - _deploy_core_endpoint defaults to 'AllTraffic' + - _deploy_model_customization defaults to endpoint_name + """ from sagemaker.serve.model_builder import ModelBuilder mb = object.__new__(ModelBuilder) @@ -650,8 +655,45 @@ def fake_deploy(**kw): initial_instance_count=1, ) - assert captured["variant_name"] == "AllTraffic" + # variant_name should NOT be in kwargs when not explicitly provided + assert "variant_name" not in captured # Optional params should not be in kwargs when not provided assert "data_cache_config" not in captured assert "base_inference_component_name" not in captured assert "container" not in captured + + def test_deploy_forwards_variant_name_none_is_not_forwarded(self): + """deploy(variant_name=None) should NOT forward variant_name. + + None is the default, so it should behave the same as not providing it. + """ + from sagemaker.serve.model_builder import ModelBuilder + + mb = object.__new__(ModelBuilder) + mb.built_model = MagicMock() + mb._deployed = False + mb._is_sharded_model = False + mb.model_name = "test" + mb.instance_type = "ml.m5.large" + mb.endpoint_name = None + mb.mode = None + mb.model_server = None + mb._is_model_customization = MagicMock(return_value=False) + + captured = {} + + def fake_deploy(**kw): + captured.update(kw) + return MagicMock() + + mb._deploy = fake_deploy + + mb.deploy( + endpoint_name="ep", + instance_type="ml.m5.large", + initial_instance_count=1, + variant_name=None, + ) + + # variant_name=None should not be forwarded + assert "variant_name" not in captured