diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder.py b/sagemaker-serve/src/sagemaker/serve/model_builder.py index 7c7af2defc..20062ae62e 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder.py @@ -45,6 +45,8 @@ ModelLifeCycle, DriftCheckBaselines, InferenceComponentComputeResourceRequirements, + InferenceComponentDataCacheConfig, + InferenceComponentContainerSpecification, ) from sagemaker.core.resources import ( ModelPackage, @@ -2700,6 +2702,52 @@ def _wait_for_endpoint( return desc + @staticmethod + def _apply_optional_ic_params(inference_component_spec, **kwargs): + """Apply optional IC-level parameters to an inference component spec dict. + + Wires data_cache_config, base_inference_component_name, and container + into the given inference_component_spec dict. Shared by + _deploy_core_endpoint and _update_inference_component to avoid + code duplication. + + Args: + inference_component_spec (dict): The spec dict to mutate in-place. + **kwargs: May contain data_cache_config, base_inference_component_name, + and container. 
+ """ + from sagemaker.serve.model_builder_utils import _ModelBuilderUtils + + ic_data_cache_config = kwargs.get("data_cache_config") + if ic_data_cache_config is not None: + resolved = _ModelBuilderUtils._resolve_data_cache_config( + None, ic_data_cache_config + ) + if resolved is not None: + inference_component_spec["DataCacheConfig"] = { + "EnableCaching": resolved.enable_caching + } + + ic_base_component_name = kwargs.get("base_inference_component_name") + if ic_base_component_name is not None: + inference_component_spec["BaseInferenceComponentName"] = ic_base_component_name + + ic_container = kwargs.get("container") + if ic_container is not None: + resolved_container = _ModelBuilderUtils._resolve_container_spec( + None, ic_container + ) + if resolved_container is not None: + container_dict = {} + if resolved_container.image: + container_dict["Image"] = resolved_container.image + if resolved_container.artifact_url: + container_dict["ArtifactUrl"] = resolved_container.artifact_url + if resolved_container.environment: + container_dict["Environment"] = resolved_container.environment + if container_dict: + inference_component_spec["Container"] = container_dict + def _deploy_core_endpoint(self, **kwargs): # Extract and update self parameters initial_instance_count = kwargs.get( @@ -2849,62 +2897,6 @@ def _deploy_core_endpoint(self, **kwargs): if self.role_arn is None: raise ValueError("Role can not be null for deploying a model") - routing_config = _resolve_routing_config(routing_config) - - if ( - inference_recommendation_id is not None - or self.inference_recommender_job_results is not None - ): - instance_type, initial_instance_count = self._update_params( - instance_type=instance_type, - initial_instance_count=initial_instance_count, - accelerator_type=accelerator_type, - async_inference_config=async_inference_config, - serverless_inference_config=serverless_inference_config, - explainer_config=explainer_config, - 
inference_recommendation_id=inference_recommendation_id, - inference_recommender_job_results=self.inference_recommender_job_results, - ) - - is_async = async_inference_config is not None - if is_async and not isinstance(async_inference_config, AsyncInferenceConfig): - raise ValueError("async_inference_config needs to be a AsyncInferenceConfig object") - - is_explainer_enabled = explainer_config is not None - if is_explainer_enabled and not isinstance(explainer_config, ExplainerConfig): - raise ValueError("explainer_config needs to be a ExplainerConfig object") - - is_serverless = serverless_inference_config is not None - if not is_serverless and not (instance_type and initial_instance_count): - raise ValueError( - "Must specify instance type and instance count unless using serverless inference" - ) - - if is_serverless and not isinstance(serverless_inference_config, ServerlessInferenceConfig): - raise ValueError( - "serverless_inference_config needs to be a ServerlessInferenceConfig object" - ) - - if self._is_sharded_model: - if endpoint_type != EndpointType.INFERENCE_COMPONENT_BASED: - logger.warning( - "Forcing INFERENCE_COMPONENT_BASED endpoint for sharded model. ADVISORY - " - "Use INFERENCE_COMPONENT_BASED endpoints over MODEL_BASED endpoints." - ) - endpoint_type = EndpointType.INFERENCE_COMPONENT_BASED - - if self._enable_network_isolation: - raise ValueError( - "EnableNetworkIsolation cannot be set to True since SageMaker Fast Model " - "Loading of model requires network access." - ) - - if resources and resources.num_cpus and resources.num_cpus > 0: - logger.warning( - "NumberOfCpuCoresRequired should be 0 for the best experience with SageMaker " - "Fast Model Loading. Configure by setting `num_cpus` to 0 in `resources`." 
- ) - if endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED: if update_endpoint: raise ValueError( @@ -2931,10 +2923,14 @@ def _deploy_core_endpoint(self, **kwargs): else: managed_instance_scaling_config["MinInstanceCount"] = initial_instance_count + # Use user-provided variant_name or default to "AllTraffic" + ic_variant_name = kwargs.get("variant_name", "AllTraffic") + if not self.sagemaker_session.endpoint_in_service_or_not(self.endpoint_name): production_variant = session_helper.production_variant( instance_type=instance_type, initial_instance_count=initial_instance_count, + variant_name=ic_variant_name, volume_size=volume_size, model_data_download_timeout=model_data_download_timeout, container_startup_health_check_timeout=container_startup_health_check_timeout, @@ -2978,6 +2974,10 @@ def _deploy_core_endpoint(self, **kwargs): "StartupParameters": startup_parameters, "ComputeResourceRequirements": resources.get_compute_resource_requirements(), } + + # Wire optional IC-level parameters into the specification + self._apply_optional_ic_params(inference_component_spec, **kwargs) + runtime_config = {"CopyCount": resources.copy_count} self.inference_component_name = ( inference_component_name @@ -2989,7 +2989,7 @@ def _deploy_core_endpoint(self, **kwargs): self.sagemaker_session.create_inference_component( inference_component_name=self.inference_component_name, endpoint_name=self.endpoint_name, - variant_name="AllTraffic", # default variant name + variant_name=ic_variant_name, specification=inference_component_spec, runtime_config=runtime_config, tags=tags, @@ -3168,6 +3168,10 @@ def _update_inference_component( "StartupParameters": startup_parameters, "ComputeResourceRequirements": compute_rr, } + + # Wire optional IC-level parameters into the update specification + self._apply_optional_ic_params(inference_component_spec, **kwargs) + runtime_config = {"CopyCount": resource_requirements.copy_count} return self.sagemaker_session.update_inference_component( @@ 
-4127,6 +4131,11 @@ def deploy( ] = None, custom_orchestrator_instance_type: str = None, custom_orchestrator_initial_instance_count: int = None, + inference_component_name: Optional[str] = None, + data_cache_config: Optional[Union["InferenceComponentDataCacheConfig", Dict[str, Any]]] = None, + base_inference_component_name: Optional[str] = None, + container: Optional[Union["InferenceComponentContainerSpecification", Dict[str, Any]]] = None, + variant_name: Optional[str] = None, **kwargs, ) -> Union[Endpoint, LocalEndpoint, Transformer]: """Deploy the built model to an ``Endpoint``. @@ -4160,6 +4169,26 @@ def deploy( orchestrator deployment. (Default: None). custom_orchestrator_initial_instance_count (int, optional): Initial instance count for custom orchestrator deployment. (Default: None). + inference_component_name (str, optional): The name of the inference component + to create. Only used for inference-component-based endpoints. If not specified, + a unique name is generated from the model name. (Default: None). + data_cache_config (Union[InferenceComponentDataCacheConfig, dict], optional): + Data cache configuration for the inference component. Enables caching of model + artifacts and container images on instances for faster auto-scaling cold starts. + Can be a dict with 'enable_caching' key (e.g., {'enable_caching': True}) or an + InferenceComponentDataCacheConfig instance. (Default: None). + base_inference_component_name (str, optional): Name of the base inference component + for adapter deployments (e.g., LoRA adapters attached to a base model). + (Default: None). + container (Union[InferenceComponentContainerSpecification, dict], optional): + Custom container specification for the inference component, including image URI, + artifact URL, and environment variables. Can be a dict with keys 'image', + 'artifact_url', 'environment' or an InferenceComponentContainerSpecification + instance. (Default: None). 
+ variant_name (str, optional): The name of the production variant to deploy to. + If not provided (or explicitly ``None``), defaults to ``'AllTraffic'``. + (Default: None). + Returns: Union[Endpoint, LocalEndpoint, Transformer]: A ``sagemaker.core.resources.Endpoint`` resource representing the deployed endpoint, a ``LocalEndpoint`` for local mode, @@ -4182,6 +4211,21 @@ def deploy( if not hasattr(self, "built_model") and not hasattr(self, "_deployables"): raise ValueError("Model needs to be built before deploying") + # Only forward variant_name when explicitly provided by the caller. + # Each downstream path has its own default: + # - _deploy_core_endpoint defaults to "AllTraffic" + # - _deploy_model_customization defaults to endpoint_name + if variant_name is not None: + kwargs["variant_name"] = variant_name + if inference_component_name is not None: + kwargs["inference_component_name"] = inference_component_name + if data_cache_config is not None: + kwargs["data_cache_config"] = data_cache_config + if base_inference_component_name is not None: + kwargs["base_inference_component_name"] = base_inference_component_name + if container is not None: + kwargs["container"] = container + # Handle model customization deployment if self._is_model_customization(): logger.info("Deploying Model Customization model") @@ -4338,8 +4382,13 @@ def _deploy_model_customization( initial_instance_count: int = 1, inference_component_name: Optional[str] = None, inference_config: Optional[ResourceRequirements] = None, + variant_name: Optional[str] = None, + data_cache_config: Optional[Union["InferenceComponentDataCacheConfig", Dict[str, Any]]] = None, **kwargs, ) -> Endpoint: + # NOTE: For backward compatibility, model customization deployments + # default variant_name to endpoint_name (not "AllTraffic") when the + # caller does not provide an explicit value. """Deploy a model customization (fine-tuned) model to an endpoint with inference components. 
This method handles the special deployment flow for fine-tuned models, creating: @@ -4379,6 +4428,15 @@ def _deploy_model_customization( # Fetch model package model_package = self._fetch_model_package() + # Resolve variant_name: preserve backward-compatible default of + # endpoint_name for model customization deployments. + effective_variant_name = variant_name or endpoint_name or "AllTraffic" + + # Resolve data_cache_config if provided + resolved_data_cache_config = None + if data_cache_config is not None: + resolved_data_cache_config = self._resolve_data_cache_config(data_cache_config) + # Check if endpoint exists is_existing_endpoint = self._does_endpoint_exist(endpoint_name) @@ -4387,7 +4445,7 @@ def _deploy_model_customization( endpoint_config_name=endpoint_name, production_variants=[ ProductionVariant( - variant_name=endpoint_name, + variant_name=effective_variant_name, instance_type=self.instance_type, initial_instance_count=initial_instance_count or 1, ) @@ -4428,6 +4486,7 @@ def _deploy_model_customization( base_ic_spec = InferenceComponentSpecification( model_name=self.built_model.model_name, + data_cache_config=resolved_data_cache_config, ) if inference_config is not None: base_ic_spec.compute_resource_requirements = ( @@ -4444,7 +4503,7 @@ def _deploy_model_customization( InferenceComponent.create( inference_component_name=base_ic_name, endpoint_name=endpoint_name, - variant_name=endpoint_name, + variant_name=effective_variant_name, specification=base_ic_spec, runtime_config=InferenceComponentRuntimeConfig(copy_count=1), tags=[{"key": "Base", "value": base_model_recipe_name}], @@ -4486,7 +4545,8 @@ def _deploy_model_customization( ic_spec = InferenceComponentSpecification( container=InferenceComponentContainerSpecification( image=self.image_uri, artifact_url=artifact_url, environment=self.env_vars - ) + ), + data_cache_config=resolved_data_cache_config, ) if inference_config is not None: @@ -4504,7 +4564,7 @@ def _deploy_model_customization( 
InferenceComponent.create( inference_component_name=inference_component_name, endpoint_name=endpoint_name, - variant_name=endpoint_name, + variant_name=effective_variant_name, specification=ic_spec, runtime_config=InferenceComponentRuntimeConfig(copy_count=1), ) diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py b/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py index 56f3070346..4cf7d56095 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py @@ -78,6 +78,10 @@ def build(self): from sagemaker.serve.utils.hardware_detector import _total_inference_model_size_mib from sagemaker.serve.utils.types import ModelServer from sagemaker.core.resources import Model +from sagemaker.core.shapes import ( + InferenceComponentDataCacheConfig, + InferenceComponentContainerSpecification, +) # MLflow imports from sagemaker.serve.model_format.mlflow.constants import ( @@ -3369,6 +3373,80 @@ def _extract_speculative_draft_model_provider( return "auto" + def _resolve_data_cache_config( + self, + data_cache_config: Union[InferenceComponentDataCacheConfig, Dict[str, Any], None], + ) -> Optional[InferenceComponentDataCacheConfig]: + """Resolve data_cache_config to InferenceComponentDataCacheConfig. + + Args: + data_cache_config: Either a dict with 'enable_caching' key (and any future + fields supported by InferenceComponentDataCacheConfig), + an InferenceComponentDataCacheConfig instance, or None. + + Returns: + InferenceComponentDataCacheConfig or None. + + Raises: + ValueError: If data_cache_config is an unsupported type or dict + is missing the required 'enable_caching' key. 
+ """ + if data_cache_config is None: + return None + + if isinstance(data_cache_config, InferenceComponentDataCacheConfig): + return data_cache_config + elif isinstance(data_cache_config, dict): + if "enable_caching" not in data_cache_config: + raise ValueError( + "data_cache_config dict must contain the required 'enable_caching' key. " + "Example: {'enable_caching': True}" + ) + # Pass only 'enable_caching' to avoid Pydantic validation errors + # if the model has extra='forbid'. As new fields are added to + # InferenceComponentDataCacheConfig, add them here. + return InferenceComponentDataCacheConfig( + enable_caching=data_cache_config["enable_caching"] + ) + else: + raise ValueError( + f"data_cache_config must be a dict with 'enable_caching' key or an " + f"InferenceComponentDataCacheConfig instance, got {type(data_cache_config)}" + ) + + def _resolve_container_spec( + self, + container: Union[InferenceComponentContainerSpecification, Dict[str, Any], None], + ) -> Optional[InferenceComponentContainerSpecification]: + """Resolve container to InferenceComponentContainerSpecification. + + Args: + container: Either a dict with container config keys (image, artifact_url, + environment), an InferenceComponentContainerSpecification instance, or None. + + Returns: + InferenceComponentContainerSpecification or None. + + Raises: + ValueError: If container is an unsupported type. 
+ """ + if container is None: + return None + + if isinstance(container, InferenceComponentContainerSpecification): + return container + elif isinstance(container, dict): + # Only pass known keys to avoid Pydantic validation errors + # if the model has extra='forbid' configured + known_keys = {"image", "artifact_url", "environment"} + filtered = {k: v for k, v in container.items() if k in known_keys} + return InferenceComponentContainerSpecification(**filtered) + else: + raise ValueError( + f"container must be a dict or an InferenceComponentContainerSpecification " + f"instance, got {type(container)}" + ) + def get_huggingface_model_metadata( self, model_id: str, hf_hub_token: Optional[str] = None ) -> dict: diff --git a/sagemaker-serve/tests/integ/test_ic_deploy_params_integration.py b/sagemaker-serve/tests/integ/test_ic_deploy_params_integration.py new file mode 100644 index 0000000000..ade859ddb4 --- /dev/null +++ b/sagemaker-serve/tests/integ/test_ic_deploy_params_integration.py @@ -0,0 +1,158 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Integration tests for IC-level deploy parameters (data_cache_config, variant_name).""" +from __future__ import absolute_import + +import json +import uuid +import time +import random +import logging + +import boto3 +import pytest + +from sagemaker.serve.model_builder import ModelBuilder +from sagemaker.core.jumpstart.configs import JumpStartConfig +from sagemaker.core.inference_config import ResourceRequirements +from sagemaker.core.resources import ( + Endpoint, + EndpointConfig, + InferenceComponent, +) +from sagemaker.train.configs import Compute + +logger = logging.getLogger(__name__) + +# Use the same JumpStart model as test_jumpstart_integration.py +MODEL_ID = "huggingface-llm-falcon-7b-bf16" + + +def _cleanup_endpoint(endpoint_name, sagemaker_client): + """Delete endpoint, endpoint config, and all inference components.""" + try: + # Delete inference components first + paginator = sagemaker_client.get_paginator("list_inference_components") + for page in paginator.paginate(EndpointNameEquals=endpoint_name): + for ic in page.get("InferenceComponents", []): + ic_name = ic["InferenceComponentName"] + try: + sagemaker_client.delete_inference_component( + InferenceComponentName=ic_name + ) + logger.info("Deleted inference component: %s", ic_name) + except Exception as e: + logger.warning("Failed to delete IC %s: %s", ic_name, e) + except Exception as e: + logger.warning("Failed to list/delete ICs for %s: %s", endpoint_name, e) + + try: + sagemaker_client.delete_endpoint(EndpointName=endpoint_name) + logger.info("Deleted endpoint: %s", endpoint_name) + except Exception as e: + logger.warning("Failed to delete endpoint %s: %s", endpoint_name, e) + + try: + sagemaker_client.delete_endpoint_config(EndpointConfigName=endpoint_name) + logger.info("Deleted endpoint config: %s", endpoint_name) + except Exception as e: + logger.warning("Failed to delete endpoint config %s: %s", endpoint_name, e) + + +def _cleanup_model(model_name, sagemaker_client): + """Delete a 
SageMaker model.""" + try: + sagemaker_client.delete_model(ModelName=model_name) + logger.info("Deleted model: %s", model_name) + except Exception as e: + logger.warning("Failed to delete model %s: %s", model_name, e) + + +@pytest.mark.slow_test +def test_deploy_with_data_cache_config_and_variant_name_via_ic_path(): + """Deploy a JumpStart model via the IC-based path with data_cache_config and custom variant_name. + + Verifies: + - The IC was created with DataCacheConfig.EnableCaching == True + - The variant name matches the custom value (not 'AllTraffic') + """ + unique_id = uuid.uuid4().hex[:8] + model_name = f"ic-params-test-model-{unique_id}" + endpoint_name = f"ic-params-test-ep-{unique_id}" + custom_variant = f"Variant-{unique_id}" + + sagemaker_client = boto3.client("sagemaker") + ic_name = None + + try: + # Build + compute = Compute(instance_type="ml.g5.2xlarge") + jumpstart_config = JumpStartConfig(model_id=MODEL_ID) + model_builder = ModelBuilder.from_jumpstart_config( + jumpstart_config=jumpstart_config, compute=compute + ) + core_model = model_builder.build(model_name=model_name) + logger.info("Model created: %s", core_model.model_name) + + # Deploy with IC path (ResourceRequirements triggers IC-based endpoint) + resources = ResourceRequirements( + requests={ + "memory": 8192, + "num_accelerators": 1, + "num_cpus": 2, + "copies": 1, + } + ) + core_endpoint = model_builder.deploy( + endpoint_name=endpoint_name, + initial_instance_count=1, + inference_config=resources, + data_cache_config={"enable_caching": True}, + variant_name=custom_variant, + ) + logger.info("Endpoint created: %s", core_endpoint.endpoint_name) + + # Find the inference component that was created + ic_name = model_builder.inference_component_name + assert ic_name is not None, "inference_component_name should be set after deploy" + + # Describe the inference component via boto3 + ic_desc = sagemaker_client.describe_inference_component( + InferenceComponentName=ic_name + ) + + # Verify 
DataCacheConfig.EnableCaching == True + spec = ic_desc.get("Specification", {}) + data_cache = spec.get("DataCacheConfig", {}) + assert data_cache.get("EnableCaching") is True, ( + f"Expected DataCacheConfig.EnableCaching=True, got {data_cache}" + ) + + # Verify variant name matches custom value + actual_variant = ic_desc.get("VariantName") + assert actual_variant == custom_variant, ( + f"Expected VariantName='{custom_variant}', got '{actual_variant}'" + ) + + logger.info( + "Test passed: IC '%s' has DataCacheConfig.EnableCaching=True and VariantName='%s'", + ic_name, + custom_variant, + ) + + finally: + _cleanup_endpoint(endpoint_name, sagemaker_client) + _cleanup_model(model_name, sagemaker_client) + + + diff --git a/tests/unit/sagemaker/serve/test_resolve_ic_params.py b/tests/unit/sagemaker/serve/test_resolve_ic_params.py new file mode 100644 index 0000000000..ceb697bac0 --- /dev/null +++ b/tests/unit/sagemaker/serve/test_resolve_ic_params.py @@ -0,0 +1,699 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Unit tests for IC parameter resolvers and wiring logic.""" +from __future__ import absolute_import + +import pytest +from unittest.mock import MagicMock, patch, ANY + +from sagemaker.core.shapes import ( + InferenceComponentDataCacheConfig, + InferenceComponentContainerSpecification, +) +from sagemaker.serve.model_builder_utils import _ModelBuilderUtils + + +class ConcreteUtils(_ModelBuilderUtils): + """Concrete class to test mixin methods. 
+ + _ModelBuilderUtils is a mixin that does not define __init__, + so this can be instantiated without arguments. + """ + pass + + +@pytest.fixture +def utils(): + return ConcreteUtils() + + +# ============================================================ +# Tests for _resolve_data_cache_config +# ============================================================ + +class TestResolveDataCacheConfig: + def test_none_returns_none(self, utils): + assert utils._resolve_data_cache_config(None) is None + + def test_already_typed_passthrough(self, utils): + config = InferenceComponentDataCacheConfig(enable_caching=True) + result = utils._resolve_data_cache_config(config) + assert result is config + assert result.enable_caching is True + + def test_dict_with_enable_caching_true(self, utils): + result = utils._resolve_data_cache_config({"enable_caching": True}) + assert isinstance(result, InferenceComponentDataCacheConfig) + assert result.enable_caching is True + + def test_dict_with_enable_caching_false(self, utils): + result = utils._resolve_data_cache_config({"enable_caching": False}) + assert isinstance(result, InferenceComponentDataCacheConfig) + assert result.enable_caching is False + + def test_dict_missing_enable_caching_raises(self, utils): + with pytest.raises(ValueError, match="must contain the required 'enable_caching' key"): + utils._resolve_data_cache_config({}) + + def test_dict_with_extra_keys_still_works(self, utils): + """Extra keys in the input dict are ignored. + + The resolver only extracts 'enable_caching' from the dict, so extra keys + do not cause Pydantic validation errors even if the model forbids extras. + We verify the result has enable_caching=True and does not expose extra_key. 
+ """ + result = utils._resolve_data_cache_config( + {"enable_caching": True, "extra_key": "ignored"} + ) + assert isinstance(result, InferenceComponentDataCacheConfig) + assert result.enable_caching is True + # Verify extra_key is not present on the result object + assert not hasattr(result, "extra_key") or getattr(result, "extra_key", None) is None + + def test_invalid_type_raises(self, utils): + with pytest.raises(ValueError, match="data_cache_config must be a dict"): + utils._resolve_data_cache_config("invalid") + + def test_invalid_type_int_raises(self, utils): + with pytest.raises(ValueError, match="data_cache_config must be a dict"): + utils._resolve_data_cache_config(42) + + def test_invalid_type_list_raises(self, utils): + with pytest.raises(ValueError, match="data_cache_config must be a dict"): + utils._resolve_data_cache_config([True]) + + +# ============================================================ +# Tests for _resolve_container_spec +# ============================================================ + +class TestResolveContainerSpec: + def test_none_returns_none(self, utils): + assert utils._resolve_container_spec(None) is None + + def test_already_typed_passthrough(self, utils): + spec = InferenceComponentContainerSpecification( + image="my-image:latest", + artifact_url="s3://bucket/artifact", + environment={"KEY": "VALUE"}, + ) + result = utils._resolve_container_spec(spec) + assert result is spec + + def test_dict_full(self, utils): + result = utils._resolve_container_spec({ + "image": "my-image:latest", + "artifact_url": "s3://bucket/artifact", + "environment": {"KEY": "VALUE"}, + }) + assert isinstance(result, InferenceComponentContainerSpecification) + assert result.image == "my-image:latest" + assert result.artifact_url == "s3://bucket/artifact" + assert result.environment == {"KEY": "VALUE"} + + def test_dict_image_only(self, utils): + result = utils._resolve_container_spec({"image": "my-image:latest"}) + assert isinstance(result, 
InferenceComponentContainerSpecification) + assert result.image == "my-image:latest" + + def test_dict_artifact_url_only(self, utils): + result = utils._resolve_container_spec({"artifact_url": "s3://bucket/model.tar.gz"}) + assert isinstance(result, InferenceComponentContainerSpecification) + assert result.artifact_url == "s3://bucket/model.tar.gz" + + def test_dict_environment_only(self, utils): + result = utils._resolve_container_spec({"environment": {"A": "B"}}) + assert isinstance(result, InferenceComponentContainerSpecification) + assert result.environment == {"A": "B"} + + def test_dict_empty(self, utils): + """Empty dict creates a spec with no fields set.""" + result = utils._resolve_container_spec({}) + assert isinstance(result, InferenceComponentContainerSpecification) + + def test_dict_with_extra_keys(self, utils): + """Extra keys are filtered out before passing to the Pydantic constructor. + + This ensures compatibility even if InferenceComponentContainerSpecification + has extra='forbid' in its Pydantic model config. 
+ """ + result = utils._resolve_container_spec({ + "image": "img", + "unknown_key": "ignored", + }) + assert isinstance(result, InferenceComponentContainerSpecification) + assert result.image == "img" + + def test_invalid_type_raises(self, utils): + with pytest.raises(ValueError, match="container must be a dict"): + utils._resolve_container_spec("invalid") + + def test_invalid_type_int_raises(self, utils): + with pytest.raises(ValueError, match="container must be a dict"): + utils._resolve_container_spec(123) + + def test_invalid_type_list_raises(self, utils): + with pytest.raises(ValueError, match="container must be a dict"): + utils._resolve_container_spec([{"image": "img"}]) + + +# ============================================================ +# Tests for _apply_optional_ic_params helper +# ============================================================ + +class TestApplyOptionalIcParams: + """Tests for the static helper that wires optional IC params into a spec dict.""" + + def test_no_params_no_mutation(self): + from sagemaker.serve.model_builder import ModelBuilder + spec = {"ModelName": "m"} + ModelBuilder._apply_optional_ic_params(spec) + assert "DataCacheConfig" not in spec + assert "BaseInferenceComponentName" not in spec + assert "Container" not in spec + + def test_data_cache_config_dict(self): + from sagemaker.serve.model_builder import ModelBuilder + spec = {"ModelName": "m"} + ModelBuilder._apply_optional_ic_params( + spec, data_cache_config={"enable_caching": True} + ) + assert spec["DataCacheConfig"] == {"EnableCaching": True} + + def test_data_cache_config_typed(self): + from sagemaker.serve.model_builder import ModelBuilder + spec = {"ModelName": "m"} + cfg = InferenceComponentDataCacheConfig(enable_caching=False) + ModelBuilder._apply_optional_ic_params(spec, data_cache_config=cfg) + assert spec["DataCacheConfig"] == {"EnableCaching": False} + + def test_base_inference_component_name(self): + from sagemaker.serve.model_builder import ModelBuilder + 
def _ic_resource_requirements():
    """Return the ResourceRequirements fixture shared by the IC wiring tests.

    Every inference-component test below needs the same minimal request set
    (memory / accelerators / cpus / copies); building it in one place keeps
    each test focused on the single parameter it verifies.
    """
    from sagemaker.core.inference_config import ResourceRequirements

    return ResourceRequirements(
        requests={"memory": 8192, "num_accelerators": 1, "num_cpus": 2, "copies": 1}
    )


# ============================================================
# Tests for core wiring logic in _deploy_core_endpoint
# ============================================================

class TestDeployCoreEndpointWiring:
    """Tests that new IC parameters are correctly wired through _deploy_core_endpoint."""

    def _make_model_builder(self):
        """Create a minimally-configured ModelBuilder for testing _deploy_core_endpoint.

        Bypasses __init__ via object.__new__ and sets only the attributes that
        _deploy_core_endpoint reads; the SageMaker session is a MagicMock so
        tests can inspect the create_inference_component call.
        """
        from sagemaker.serve.model_builder import ModelBuilder

        mb = object.__new__(ModelBuilder)
        # Set minimum required attributes
        mb.model_name = "test-model"
        mb.endpoint_name = None
        mb.inference_component_name = None
        mb.instance_type = "ml.g5.2xlarge"
        mb.instance_count = 1
        mb.accelerator_type = None
        mb._tags = None
        mb.kms_key = None
        mb.async_inference_config = None
        mb.serverless_inference_config = None
        mb.model_data_download_timeout = None
        mb.resource_requirements = None
        mb.container_startup_health_check_timeout = None
        mb.inference_ami_version = None
        mb._is_sharded_model = False
        mb._enable_network_isolation = False
        mb.role_arn = "arn:aws:iam::123456789012:role/SageMakerRole"
        mb.vpc_config = None
        mb.inference_recommender_job_results = None
        mb.model_server = None
        mb.mode = None
        mb.region = "us-east-1"

        # Mock built_model
        mb.built_model = MagicMock()
        mb.built_model.model_name = "test-model"

        # Mock sagemaker_session; the endpoint is reported as already in
        # service so the "existing endpoint" code path is taken by default.
        mb.sagemaker_session = MagicMock()
        mb.sagemaker_session.endpoint_in_service_or_not.return_value = True
        mb.sagemaker_session.boto_session = MagicMock()
        mb.sagemaker_session.boto_region_name = "us-east-1"

        return mb

    def _deploy_ic(self, mb, **extra):
        """Run _deploy_core_endpoint with the common IC-based args plus *extra*."""
        mb._deploy_core_endpoint(
            endpoint_type="INFERENCE_COMPONENT_BASED",
            resources=_ic_resource_requirements(),
            instance_type="ml.g5.2xlarge",
            initial_instance_count=1,
            wait=False,
            **extra,
        )

    def _created_spec(self, mb):
        """Return the specification dict passed to create_inference_component."""
        call = mb.sagemaker_session.create_inference_component.call_args
        return call.kwargs["specification"]

    @patch("sagemaker.serve.model_builder.Endpoint")
    def test_variant_name_defaults_to_all_traffic(self, mock_endpoint_cls):
        """When variant_name is not provided, it defaults to 'AllTraffic'."""
        mb = self._make_model_builder()
        mock_endpoint_cls.get.return_value = MagicMock()

        self._deploy_ic(mb)

        # Verify create_inference_component was called with variant_name="AllTraffic"
        mb.sagemaker_session.create_inference_component.assert_called_once()
        call = mb.sagemaker_session.create_inference_component.call_args
        assert call.kwargs["variant_name"] == "AllTraffic"

    @patch("sagemaker.serve.model_builder.Endpoint")
    def test_variant_name_custom(self, mock_endpoint_cls):
        """When variant_name is provided, it is used instead of 'AllTraffic'."""
        mb = self._make_model_builder()
        mock_endpoint_cls.get.return_value = MagicMock()

        self._deploy_ic(mb, variant_name="MyVariant")

        call = mb.sagemaker_session.create_inference_component.call_args
        assert call.kwargs["variant_name"] == "MyVariant"

    @patch("sagemaker.serve.model_builder.Endpoint")
    def test_data_cache_config_wired_into_spec(self, mock_endpoint_cls):
        """data_cache_config dict is resolved and added to inference_component_spec."""
        mb = self._make_model_builder()
        mock_endpoint_cls.get.return_value = MagicMock()

        self._deploy_ic(mb, data_cache_config={"enable_caching": True})

        spec = self._created_spec(mb)
        assert "DataCacheConfig" in spec
        assert spec["DataCacheConfig"]["EnableCaching"] is True

    @patch("sagemaker.serve.model_builder.Endpoint")
    def test_base_inference_component_name_wired_into_spec(self, mock_endpoint_cls):
        """base_inference_component_name is added to inference_component_spec."""
        mb = self._make_model_builder()
        mock_endpoint_cls.get.return_value = MagicMock()

        self._deploy_ic(mb, base_inference_component_name="base-ic-name")

        spec = self._created_spec(mb)
        assert spec["BaseInferenceComponentName"] == "base-ic-name"

    @patch("sagemaker.serve.model_builder.Endpoint")
    def test_container_wired_into_spec(self, mock_endpoint_cls):
        """container dict is resolved and added to inference_component_spec."""
        mb = self._make_model_builder()
        mock_endpoint_cls.get.return_value = MagicMock()

        self._deploy_ic(
            mb,
            container={
                "image": "my-image:latest",
                "artifact_url": "s3://bucket/artifact",
                "environment": {"KEY": "VALUE"},
            },
        )

        spec = self._created_spec(mb)
        assert "Container" in spec
        assert spec["Container"]["Image"] == "my-image:latest"
        assert spec["Container"]["ArtifactUrl"] == "s3://bucket/artifact"
        assert spec["Container"]["Environment"] == {"KEY": "VALUE"}

    @patch("sagemaker.serve.model_builder.Endpoint")
    def test_no_optional_params_no_extra_keys_in_spec(self, mock_endpoint_cls):
        """When no optional IC params are provided, spec has no extra keys."""
        mb = self._make_model_builder()
        mock_endpoint_cls.get.return_value = MagicMock()

        self._deploy_ic(mb)

        spec = self._created_spec(mb)
        assert "DataCacheConfig" not in spec
        assert "BaseInferenceComponentName" not in spec
        assert "Container" not in spec

    @patch("sagemaker.serve.model_builder.Endpoint")
    def test_data_cache_config_typed_object_wired(self, mock_endpoint_cls):
        """InferenceComponentDataCacheConfig object is correctly wired."""
        mb = self._make_model_builder()
        mock_endpoint_cls.get.return_value = MagicMock()

        config = InferenceComponentDataCacheConfig(enable_caching=True)
        self._deploy_ic(mb, data_cache_config=config)

        spec = self._created_spec(mb)
        assert spec["DataCacheConfig"]["EnableCaching"] is True

    @patch("sagemaker.serve.model_builder.Endpoint")
    def test_variant_name_passed_to_production_variant_on_new_endpoint(self, mock_endpoint_cls):
        """When creating a new endpoint, variant_name is passed to production_variant."""
        mb = self._make_model_builder()
        mock_endpoint_cls.get.return_value = MagicMock()
        # Simulate endpoint does NOT exist yet
        mb.sagemaker_session.endpoint_in_service_or_not.return_value = False

        with patch("sagemaker.serve.model_builder.session_helper.production_variant") as mock_pv:
            mock_pv.return_value = {"VariantName": "CustomVariant"}
            self._deploy_ic(mb, variant_name="CustomVariant")

        # Verify production_variant was called with variant_name="CustomVariant"
        mock_pv.assert_called_once()
        assert mock_pv.call_args.kwargs.get("variant_name") == "CustomVariant"


# ============================================================
# Tests for _update_inference_component wiring
# ============================================================

class TestUpdateInferenceComponentWiring:
    """Tests that _update_inference_component correctly wires optional IC params."""

    def _make_model_builder(self):
        """Create a ModelBuilder stub with a mocked SageMaker session."""
        from sagemaker.serve.model_builder import ModelBuilder

        mb = object.__new__(ModelBuilder)
        mb.model_name = "test-model"
        mb.sagemaker_session = MagicMock()
        return mb

    def _updated_spec(self, mb, **extra):
        """Call _update_inference_component with *extra* and return the spec dict."""
        mb._update_inference_component("my-ic", _ic_resource_requirements(), **extra)
        call = mb.sagemaker_session.update_inference_component.call_args
        return call.kwargs["specification"]

    def test_update_ic_with_data_cache_config(self):
        mb = self._make_model_builder()
        spec = self._updated_spec(mb, data_cache_config={"enable_caching": True})
        assert spec["DataCacheConfig"] == {"EnableCaching": True}

    def test_update_ic_with_container(self):
        mb = self._make_model_builder()
        spec = self._updated_spec(mb, container={"image": "img:v1"})
        assert spec["Container"] == {"Image": "img:v1"}

    def test_update_ic_with_base_inference_component_name(self):
        mb = self._make_model_builder()
        spec = self._updated_spec(mb, base_inference_component_name="base-ic")
        assert spec["BaseInferenceComponentName"] == "base-ic"

    def test_update_ic_no_optional_params(self):
        mb = self._make_model_builder()
        spec = self._updated_spec(mb)
        assert "DataCacheConfig" not in spec
        assert "BaseInferenceComponentName" not in spec
        assert "Container" not in spec


# ============================================================
# Tests for deploy() parameter forwarding
# ============================================================

class TestDeployParameterForwarding:
    """Tests that deploy() correctly forwards new IC params into kwargs."""

    def _make_model_builder(self):
        """Create a ModelBuilder stub whose _deploy records the kwargs it receives.

        Returns:
            tuple: (model_builder, captured) where ``captured`` collects every
            keyword argument passed to the patched ``_deploy``.
        """
        from sagemaker.serve.model_builder import ModelBuilder

        mb = object.__new__(ModelBuilder)
        mb.built_model = MagicMock()
        mb._deployed = False
        mb._is_sharded_model = False
        mb.model_name = "test"
        mb.instance_type = "ml.m5.large"
        mb.endpoint_name = None
        mb.mode = None
        mb.model_server = None
        # Mock _is_model_customization to return False
        mb._is_model_customization = MagicMock(return_value=False)

        # Capture kwargs forwarded by deploy() into _deploy
        captured = {}

        def fake_deploy(**kw):
            captured.update(kw)
            return MagicMock()

        mb._deploy = fake_deploy
        return mb, captured

    def test_deploy_forwards_variant_name_to_kwargs(self):
        """deploy() should set kwargs['variant_name'] to the provided value."""
        mb, captured = self._make_model_builder()

        mb.deploy(
            endpoint_name="ep",
            instance_type="ml.m5.large",
            initial_instance_count=1,
            variant_name="MyVariant",
            data_cache_config={"enable_caching": True},
            base_inference_component_name="base-ic",
            container={"image": "img"},
        )

        assert captured["variant_name"] == "MyVariant"
        assert captured["data_cache_config"] == {"enable_caching": True}
        assert captured["base_inference_component_name"] == "base-ic"
        assert captured["container"] == {"image": "img"}

    def test_deploy_does_not_set_variant_name_when_not_provided(self):
        """deploy() should NOT set variant_name in kwargs when not provided.

        This allows downstream methods to use their own defaults:
        - _deploy_core_endpoint defaults to 'AllTraffic'
        - _deploy_model_customization defaults to endpoint_name
        """
        mb, captured = self._make_model_builder()

        mb.deploy(
            endpoint_name="ep",
            instance_type="ml.m5.large",
            initial_instance_count=1,
        )

        # variant_name should NOT be in kwargs when not explicitly provided
        assert "variant_name" not in captured
        # Optional params should not be in kwargs when not provided
        assert "data_cache_config" not in captured
        assert "base_inference_component_name" not in captured
        assert "container" not in captured

    def test_deploy_forwards_variant_name_none_is_not_forwarded(self):
        """deploy(variant_name=None) should NOT forward variant_name.

        None is the default, so it should behave the same as not providing it.
        """
        mb, captured = self._make_model_builder()

        mb.deploy(
            endpoint_name="ep",
            instance_type="ml.m5.large",
            initial_instance_count=1,
            variant_name=None,
        )

        # variant_name=None should not be forwarded
        assert "variant_name" not in captured