Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 121 additions & 61 deletions sagemaker-serve/src/sagemaker/serve/model_builder.py
Comment thread
aviruthen marked this conversation as resolved.
Comment thread
aviruthen marked this conversation as resolved.
Comment thread
aviruthen marked this conversation as resolved.
Comment thread
aviruthen marked this conversation as resolved.
Comment thread
aviruthen marked this conversation as resolved.
Comment thread
aviruthen marked this conversation as resolved.
Comment thread
aviruthen marked this conversation as resolved.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two fixes needed:

  1. Bug: variant_name always overrides model customization default. In deploy(), kwargs["variant_name"] = variant_name or "AllTraffic" always sets the key, so _deploy_model_customization never sees None and its backward-compat default of endpoint_name is dead code. Fix: only forward variant_name when explicitly provided:
# Replace:
kwargs["variant_name"] = variant_name or "AllTraffic"
# With:
if variant_name is not None:
    kwargs["variant_name"] = variant_name

Each downstream path already has its own default — _deploy_core_endpoint defaults to "AllTraffic" via kwargs.get("variant_name", "AllTraffic"), and _deploy_model_customization defaults to endpoint_name via variant_name or endpoint_name or "AllTraffic".

  2. Drop the second integ test (test_deploy_with_data_cache_config_via_model_customization_path). The model customization path requires ml.g5.4xlarge which has a non-adjustable account quota of 2 instances. When CI runs tests in parallel, this test competes with the existing test_model_customization_deployment.py for the same quota, causing flaky InsufficientInstanceCapacity failures. The model customization path's data_cache_config and variant_name wiring is already covered by unit tests. Keep only the first integ test (test_deploy_with_data_cache_config_and_variant_name_via_ic_path) which uses ml.g5.2xlarge.
    Also remove the TRAINING_JOB_NAME constant since it's no longer needed.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do not worry about CI failures! Removing the second integ test will fix one failure and the other failures are due to flakiness

Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
ModelLifeCycle,
DriftCheckBaselines,
InferenceComponentComputeResourceRequirements,
InferenceComponentDataCacheConfig,
Comment thread
aviruthen marked this conversation as resolved.
InferenceComponentContainerSpecification,
)
from sagemaker.core.resources import (
ModelPackage,
Expand Down Expand Up @@ -2700,6 +2702,52 @@ def _wait_for_endpoint(

return desc

@staticmethod
def _apply_optional_ic_params(inference_component_spec, **kwargs):
    """Mutate *inference_component_spec* with optional IC-level settings.

    Wires ``data_cache_config``, ``base_inference_component_name``, and
    ``container`` (when present in *kwargs*) into the given spec dict.
    Shared by ``_deploy_core_endpoint`` and ``_update_inference_component``
    so the wiring logic lives in a single place.

    Args:
        inference_component_spec (dict): The spec dict to mutate in-place.
        **kwargs: May contain ``data_cache_config``,
            ``base_inference_component_name``, and ``container``.
    """
    from sagemaker.serve.model_builder_utils import _ModelBuilderUtils

    cache_cfg = kwargs.get("data_cache_config")
    if cache_cfg is not None:
        # The resolver does not touch ``self``, so passing None is safe.
        normalized = _ModelBuilderUtils._resolve_data_cache_config(None, cache_cfg)
        if normalized is not None:
            inference_component_spec["DataCacheConfig"] = {
                "EnableCaching": normalized.enable_caching
            }

    base_name = kwargs.get("base_inference_component_name")
    if base_name is not None:
        inference_component_spec["BaseInferenceComponentName"] = base_name

    raw_container = kwargs.get("container")
    if raw_container is not None:
        spec_container = _ModelBuilderUtils._resolve_container_spec(None, raw_container)
        if spec_container is not None:
            # Only include fields that carry a truthy value, mirroring the
            # API payload shape expected by CreateInferenceComponent.
            payload = {
                key: value
                for key, value in (
                    ("Image", spec_container.image),
                    ("ArtifactUrl", spec_container.artifact_url),
                    ("Environment", spec_container.environment),
                )
                if value
            }
            if payload:
                inference_component_spec["Container"] = payload

def _deploy_core_endpoint(self, **kwargs):
# Extract and update self parameters
initial_instance_count = kwargs.get(
Expand Down Expand Up @@ -2849,62 +2897,6 @@ def _deploy_core_endpoint(self, **kwargs):
if self.role_arn is None:
raise ValueError("Role can not be null for deploying a model")

routing_config = _resolve_routing_config(routing_config)

if (
inference_recommendation_id is not None
or self.inference_recommender_job_results is not None
):
instance_type, initial_instance_count = self._update_params(
instance_type=instance_type,
initial_instance_count=initial_instance_count,
accelerator_type=accelerator_type,
async_inference_config=async_inference_config,
serverless_inference_config=serverless_inference_config,
explainer_config=explainer_config,
inference_recommendation_id=inference_recommendation_id,
inference_recommender_job_results=self.inference_recommender_job_results,
)

is_async = async_inference_config is not None
if is_async and not isinstance(async_inference_config, AsyncInferenceConfig):
raise ValueError("async_inference_config needs to be a AsyncInferenceConfig object")

is_explainer_enabled = explainer_config is not None
if is_explainer_enabled and not isinstance(explainer_config, ExplainerConfig):
raise ValueError("explainer_config needs to be a ExplainerConfig object")

is_serverless = serverless_inference_config is not None
if not is_serverless and not (instance_type and initial_instance_count):
raise ValueError(
"Must specify instance type and instance count unless using serverless inference"
)

if is_serverless and not isinstance(serverless_inference_config, ServerlessInferenceConfig):
raise ValueError(
"serverless_inference_config needs to be a ServerlessInferenceConfig object"
)

if self._is_sharded_model:
if endpoint_type != EndpointType.INFERENCE_COMPONENT_BASED:
logger.warning(
"Forcing INFERENCE_COMPONENT_BASED endpoint for sharded model. ADVISORY - "
"Use INFERENCE_COMPONENT_BASED endpoints over MODEL_BASED endpoints."
)
endpoint_type = EndpointType.INFERENCE_COMPONENT_BASED

if self._enable_network_isolation:
raise ValueError(
"EnableNetworkIsolation cannot be set to True since SageMaker Fast Model "
"Loading of model requires network access."
)

if resources and resources.num_cpus and resources.num_cpus > 0:
logger.warning(
"NumberOfCpuCoresRequired should be 0 for the best experience with SageMaker "
"Fast Model Loading. Configure by setting `num_cpus` to 0 in `resources`."
)

if endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED:
if update_endpoint:
raise ValueError(
Expand All @@ -2931,10 +2923,14 @@ def _deploy_core_endpoint(self, **kwargs):
else:
Comment thread
aviruthen marked this conversation as resolved.
managed_instance_scaling_config["MinInstanceCount"] = initial_instance_count

# Use user-provided variant_name or default to "AllTraffic"
ic_variant_name = kwargs.get("variant_name", "AllTraffic")

if not self.sagemaker_session.endpoint_in_service_or_not(self.endpoint_name):
production_variant = session_helper.production_variant(
instance_type=instance_type,
initial_instance_count=initial_instance_count,
variant_name=ic_variant_name,
volume_size=volume_size,
model_data_download_timeout=model_data_download_timeout,
container_startup_health_check_timeout=container_startup_health_check_timeout,
Expand Down Expand Up @@ -2978,6 +2974,10 @@ def _deploy_core_endpoint(self, **kwargs):
"StartupParameters": startup_parameters,
Comment thread
aviruthen marked this conversation as resolved.
"ComputeResourceRequirements": resources.get_compute_resource_requirements(),
}

# Wire optional IC-level parameters into the specification
self._apply_optional_ic_params(inference_component_spec, **kwargs)

runtime_config = {"CopyCount": resources.copy_count}
self.inference_component_name = (
inference_component_name
Expand All @@ -2989,7 +2989,7 @@ def _deploy_core_endpoint(self, **kwargs):
self.sagemaker_session.create_inference_component(
inference_component_name=self.inference_component_name,
endpoint_name=self.endpoint_name,
variant_name="AllTraffic", # default variant name
variant_name=ic_variant_name,
specification=inference_component_spec,
runtime_config=runtime_config,
tags=tags,
Expand Down Expand Up @@ -3168,6 +3168,10 @@ def _update_inference_component(
"StartupParameters": startup_parameters,
"ComputeResourceRequirements": compute_rr,
}

# Wire optional IC-level parameters into the update specification
self._apply_optional_ic_params(inference_component_spec, **kwargs)

runtime_config = {"CopyCount": resource_requirements.copy_count}

return self.sagemaker_session.update_inference_component(
Expand Down Expand Up @@ -4127,6 +4131,11 @@ def deploy(
] = None,
custom_orchestrator_instance_type: str = None,
custom_orchestrator_initial_instance_count: int = None,
inference_component_name: Optional[str] = None,
Comment thread
aviruthen marked this conversation as resolved.
data_cache_config: Optional[Union["InferenceComponentDataCacheConfig", Dict[str, Any]]] = None,
Comment thread
aviruthen marked this conversation as resolved.
Comment thread
aviruthen marked this conversation as resolved.
Comment thread
aviruthen marked this conversation as resolved.
base_inference_component_name: Optional[str] = None,
container: Optional[Union["InferenceComponentContainerSpecification", Dict[str, Any]]] = None,
Comment thread
aviruthen marked this conversation as resolved.
Comment thread
aviruthen marked this conversation as resolved.
variant_name: Optional[str] = None,
**kwargs,
) -> Union[Endpoint, LocalEndpoint, Transformer]:
"""Deploy the built model to an ``Endpoint``.
Expand Down Expand Up @@ -4160,6 +4169,26 @@ def deploy(
orchestrator deployment. (Default: None).
Comment thread
aviruthen marked this conversation as resolved.
custom_orchestrator_initial_instance_count (int, optional): Initial instance count
for custom orchestrator deployment. (Default: None).
inference_component_name (str, optional): The name of the inference component
to create. Only used for inference-component-based endpoints. If not specified,
a unique name is generated from the model name. (Default: None).
data_cache_config (Union[InferenceComponentDataCacheConfig, dict], optional):
Data cache configuration for the inference component. Enables caching of model
artifacts and container images on instances for faster auto-scaling cold starts.
Can be a dict with 'enable_caching' key (e.g., {'enable_caching': True}) or an
InferenceComponentDataCacheConfig instance. (Default: None).
base_inference_component_name (str, optional): Name of the base inference component
for adapter deployments (e.g., LoRA adapters attached to a base model).
(Default: None).
container (Union[InferenceComponentContainerSpecification, dict], optional):
Custom container specification for the inference component, including image URI,
artifact URL, and environment variables. Can be a dict with keys 'image',
'artifact_url', 'environment' or an InferenceComponentContainerSpecification
instance. (Default: None).
variant_name (str, optional): The name of the production variant to deploy to.
If not provided (or explicitly ``None``), defaults to ``'AllTraffic'``.
Comment thread
aviruthen marked this conversation as resolved.
(Default: None).

Returns:
Union[Endpoint, LocalEndpoint, Transformer]: A ``sagemaker.core.resources.Endpoint``
resource representing the deployed endpoint, a ``LocalEndpoint`` for local mode,
Expand All @@ -4182,6 +4211,21 @@ def deploy(
if not hasattr(self, "built_model") and not hasattr(self, "_deployables"):
Comment thread
aviruthen marked this conversation as resolved.
Comment thread
aviruthen marked this conversation as resolved.
raise ValueError("Model needs to be built before deploying")

# Only forward variant_name when explicitly provided by the caller.
# Each downstream path has its own default:
# - _deploy_core_endpoint defaults to "AllTraffic"
# - _deploy_model_customization defaults to endpoint_name
if variant_name is not None:
kwargs["variant_name"] = variant_name
if inference_component_name is not None:
kwargs["inference_component_name"] = inference_component_name
if data_cache_config is not None:
kwargs["data_cache_config"] = data_cache_config
if base_inference_component_name is not None:
kwargs["base_inference_component_name"] = base_inference_component_name
if container is not None:
kwargs["container"] = container
Comment thread
aviruthen marked this conversation as resolved.

# Handle model customization deployment
if self._is_model_customization():
logger.info("Deploying Model Customization model")
Expand Down Expand Up @@ -4338,8 +4382,13 @@ def _deploy_model_customization(
initial_instance_count: int = 1,
inference_component_name: Optional[str] = None,
inference_config: Optional[ResourceRequirements] = None,
variant_name: Optional[str] = None,
Comment thread
aviruthen marked this conversation as resolved.
data_cache_config: Optional[Union["InferenceComponentDataCacheConfig", Dict[str, Any]]] = None,
**kwargs,
) -> Endpoint:
# NOTE: For backward compatibility, model customization deployments
# default variant_name to endpoint_name (not "AllTraffic") when the
# caller does not provide an explicit value.
"""Deploy a model customization (fine-tuned) model to an endpoint with inference components.

This method handles the special deployment flow for fine-tuned models, creating:
Expand Down Expand Up @@ -4379,6 +4428,15 @@ def _deploy_model_customization(
# Fetch model package
model_package = self._fetch_model_package()

# Resolve variant_name: preserve backward-compatible default of
# endpoint_name for model customization deployments.
effective_variant_name = variant_name or endpoint_name or "AllTraffic"

# Resolve data_cache_config if provided
resolved_data_cache_config = None
if data_cache_config is not None:
resolved_data_cache_config = self._resolve_data_cache_config(data_cache_config)

# Check if endpoint exists
is_existing_endpoint = self._does_endpoint_exist(endpoint_name)

Expand All @@ -4387,7 +4445,7 @@ def _deploy_model_customization(
endpoint_config_name=endpoint_name,
Comment thread
aviruthen marked this conversation as resolved.
production_variants=[
ProductionVariant(
variant_name=endpoint_name,
variant_name=effective_variant_name,
instance_type=self.instance_type,
initial_instance_count=initial_instance_count or 1,
)
Expand Down Expand Up @@ -4428,6 +4486,7 @@ def _deploy_model_customization(

base_ic_spec = InferenceComponentSpecification(
model_name=self.built_model.model_name,
data_cache_config=resolved_data_cache_config,
)
if inference_config is not None:
base_ic_spec.compute_resource_requirements = (
Expand All @@ -4444,7 +4503,7 @@ def _deploy_model_customization(
InferenceComponent.create(
inference_component_name=base_ic_name,
endpoint_name=endpoint_name,
variant_name=endpoint_name,
variant_name=effective_variant_name,
specification=base_ic_spec,
runtime_config=InferenceComponentRuntimeConfig(copy_count=1),
tags=[{"key": "Base", "value": base_model_recipe_name}],
Expand Down Expand Up @@ -4486,7 +4545,8 @@ def _deploy_model_customization(
ic_spec = InferenceComponentSpecification(
container=InferenceComponentContainerSpecification(
image=self.image_uri, artifact_url=artifact_url, environment=self.env_vars
)
),
data_cache_config=resolved_data_cache_config,
)

if inference_config is not None:
Expand All @@ -4504,7 +4564,7 @@ def _deploy_model_customization(
InferenceComponent.create(
inference_component_name=inference_component_name,
endpoint_name=endpoint_name,
variant_name=endpoint_name,
variant_name=effective_variant_name,
specification=ic_spec,
runtime_config=InferenceComponentRuntimeConfig(copy_count=1),
)
Expand Down
78 changes: 78 additions & 0 deletions sagemaker-serve/src/sagemaker/serve/model_builder_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ def build(self):
from sagemaker.serve.utils.hardware_detector import _total_inference_model_size_mib
from sagemaker.serve.utils.types import ModelServer
from sagemaker.core.resources import Model
from sagemaker.core.shapes import (
InferenceComponentDataCacheConfig,
InferenceComponentContainerSpecification,
)

# MLflow imports
from sagemaker.serve.model_format.mlflow.constants import (
Expand Down Expand Up @@ -3369,6 +3373,80 @@ def _extract_speculative_draft_model_provider(

Comment thread
aviruthen marked this conversation as resolved.
return "auto"
Comment thread
aviruthen marked this conversation as resolved.

Comment thread
aviruthen marked this conversation as resolved.
Comment thread
aviruthen marked this conversation as resolved.
Comment thread
aviruthen marked this conversation as resolved.
def _resolve_data_cache_config(
Comment thread
aviruthen marked this conversation as resolved.
self,
data_cache_config: Union[InferenceComponentDataCacheConfig, Dict[str, Any], None],
) -> Optional[InferenceComponentDataCacheConfig]:
"""Resolve data_cache_config to InferenceComponentDataCacheConfig.

Args:
data_cache_config: Either a dict with 'enable_caching' key (and any future
fields supported by InferenceComponentDataCacheConfig),
an InferenceComponentDataCacheConfig instance, or None.

Returns:
InferenceComponentDataCacheConfig or None.

Raises:
ValueError: If data_cache_config is an unsupported type or dict
is missing the required 'enable_caching' key.
"""
if data_cache_config is None:
return None

Comment thread
aviruthen marked this conversation as resolved.
if isinstance(data_cache_config, InferenceComponentDataCacheConfig):
return data_cache_config
elif isinstance(data_cache_config, dict):
if "enable_caching" not in data_cache_config:
raise ValueError(
"data_cache_config dict must contain the required 'enable_caching' key. "
"Example: {'enable_caching': True}"
)
# Pass only 'enable_caching' to avoid Pydantic validation errors
# if the model has extra='forbid'. As new fields are added to
# InferenceComponentDataCacheConfig, add them here.
return InferenceComponentDataCacheConfig(
enable_caching=data_cache_config["enable_caching"]
)
else:
raise ValueError(
f"data_cache_config must be a dict with 'enable_caching' key or an "
f"InferenceComponentDataCacheConfig instance, got {type(data_cache_config)}"
)

Comment thread
aviruthen marked this conversation as resolved.
def _resolve_container_spec(
self,
container: Union[InferenceComponentContainerSpecification, Dict[str, Any], None],
) -> Optional[InferenceComponentContainerSpecification]:
"""Resolve container to InferenceComponentContainerSpecification.
Comment thread
aviruthen marked this conversation as resolved.

Args:
container: Either a dict with container config keys (image, artifact_url,
environment), an InferenceComponentContainerSpecification instance, or None.

Returns:
InferenceComponentContainerSpecification or None.

Raises:
ValueError: If container is an unsupported type.
"""
if container is None:
return None

if isinstance(container, InferenceComponentContainerSpecification):
return container
elif isinstance(container, dict):
# Only pass known keys to avoid Pydantic validation errors
# if the model has extra='forbid' configured
known_keys = {"image", "artifact_url", "environment"}
filtered = {k: v for k, v in container.items() if k in known_keys}
return InferenceComponentContainerSpecification(**filtered)
else:
raise ValueError(
f"container must be a dict or an InferenceComponentContainerSpecification "
f"instance, got {type(container)}"
)

def get_huggingface_model_metadata(
self, model_id: str, hf_hub_token: Optional[str] = None
) -> dict:
Expand Down
Loading
Loading