-
Notifications
You must be signed in to change notification settings - Fork 1.2k
fix: Model builder unable to (5667) #5754
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -52,7 +52,7 @@ | |
| ModelCard, | ||
| ModelPackageModelCard, | ||
| ) | ||
| from sagemaker.core.utils.utils import logger | ||
| from sagemaker.core.utils.utils import logger, Unassigned | ||
| from sagemaker.core.helper import session_helper | ||
| from sagemaker.core.helper.session_helper import ( | ||
| Session, | ||
|
|
@@ -414,6 +414,7 @@ def __post_init__(self) -> None: | |
| if self.log_level is not None: | ||
| logger.setLevel(self.log_level) | ||
|
|
||
| self._base_model_fields_resolved: bool = False | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This instance attribute is set in `__post_init__`. Consider declaring `_base_model_fields_resolved: bool = False` in the class body (or as a dataclass field). |
||
| self._warn_about_deprecated_parameters(warnings) | ||
| self._initialize_compute_config() | ||
| self._initialize_network_config() | ||
|
|
@@ -680,18 +681,182 @@ def _infer_instance_type_from_jumpstart(self) -> str: | |
|
|
||
| raise ValueError(error_msg) | ||
|
|
||
|
aviruthen marked this conversation as resolved.
aviruthen marked this conversation as resolved.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Missing type annotation on the return type. Per SDK conventions, all public and non-trivial private methods should have type annotations: `def _resolve_base_model_fields(self) -> None:` |
||
| def _resolve_base_model_fields(self): | ||
| """Auto-resolve missing BaseModel fields (hub_content_version, recipe_name). | ||
|
|
||
| When a ModelPackage's BaseModel has hub_content_name set but is missing | ||
| hub_content_version and/or recipe_name (returned as Unassigned from the | ||
| DescribeModelPackage API), this method resolves them automatically: | ||
| - hub_content_version: resolved by calling HubContent.get on SageMakerPublicHub | ||
| - recipe_name: resolved from the first recipe in the hub document's RecipeCollection | ||
|
|
||
| Note: HubContent.get() supports being called without hub_content_version, | ||
| in which case it returns the latest version of the hub content. | ||
| """ | ||
|
aviruthen marked this conversation as resolved.
|
||
| if self._base_model_fields_resolved: | ||
| return | ||
|
|
||
| model_package = self._fetch_model_package() | ||
| if not model_package: | ||
|
aviruthen marked this conversation as resolved.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The early-return pattern with `self._base_model_fields_resolved = True` is repeated for every guard. Consider extracting the lookup into a helper so the guards collapse into one:
```python
base_model = self._get_base_model_from_package()
if not base_model:
    self._base_model_fields_resolved = True
    return
```
This would also make the method significantly shorter and more readable. |
||
| self._base_model_fields_resolved = True | ||
| return | ||
|
|
||
| inference_spec = getattr(model_package, "inference_specification", None) | ||
| if not inference_spec: | ||
| self._base_model_fields_resolved = True | ||
| return | ||
|
|
||
| containers = getattr(inference_spec, "containers", None) | ||
| if not containers: | ||
| self._base_model_fields_resolved = True | ||
| return | ||
|
|
||
| base_model = getattr(containers[0], "base_model", None) | ||
| if not base_model: | ||
| self._base_model_fields_resolved = True | ||
| return | ||
|
|
||
| hub_content_name = getattr(base_model, "hub_content_name", None) | ||
| if not hub_content_name or isinstance(hub_content_name, Unassigned): | ||
| self._base_model_fields_resolved = True | ||
| return | ||
|
|
||
| hub_content_version = getattr(base_model, "hub_content_version", None) | ||
| recipe_name = getattr(base_model, "recipe_name", None) | ||
|
|
||
| # Cache the HubContent response to avoid redundant API calls | ||
| cached_hub_content = None | ||
|
|
||
| # Resolve hub_content_version if missing | ||
| if not hub_content_version or isinstance(hub_content_version, Unassigned): | ||
| logger.info( | ||
| "hub_content_version is missing for hub content '%s'. " | ||
| "Resolving automatically from SageMakerPublicHub...", | ||
| hub_content_name, | ||
| ) | ||
| try: | ||
| cached_hub_content = HubContent.get( | ||
| hub_content_type="Model", | ||
| hub_name="SageMakerPublicHub", | ||
| hub_content_name=hub_content_name, | ||
| ) | ||
| base_model.hub_content_version = ( | ||
| cached_hub_content.hub_content_version | ||
| ) | ||
| logger.info( | ||
|
aviruthen marked this conversation as resolved.
|
||
| "Resolved hub_content_version to '%s' " | ||
| "for hub content '%s'.", | ||
| cached_hub_content.hub_content_version, | ||
|
Collaborator
Author
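There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Catching `ClientError` for every error code can mask unrelated failures. Consider handling only the expected code and re-raising the rest:
```python
except ClientError as e:
    if e.response['Error']['Code'] == 'ResourceNotFoundException':
        logger.warning(...)
    else:
        raise
```
|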
||
| hub_content_name, | ||
| ) | ||
| except (ClientError, ValueError) as e: | ||
|
aviruthen marked this conversation as resolved.
|
||
| logger.warning( | ||
| "Failed to resolve hub_content_version " | ||
| "for hub content '%s': %s", | ||
| hub_content_name, | ||
| e, | ||
| ) | ||
| self._base_model_fields_resolved = True | ||
| return | ||
|
|
||
| # Resolve recipe_name if missing | ||
| if not recipe_name or isinstance(recipe_name, Unassigned): | ||
| logger.info( | ||
| "recipe_name is missing for hub content '%s'. " | ||
| "Resolving from hub document RecipeCollection...", | ||
| hub_content_name, | ||
| ) | ||
| try: | ||
| # Reuse cached hub content if available and version matches | ||
|
aviruthen marked this conversation as resolved.
|
||
| if ( | ||
| cached_hub_content is not None | ||
| and cached_hub_content.hub_content_version | ||
| == base_model.hub_content_version | ||
| ): | ||
| hub_content = cached_hub_content | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The version comparison |
||
| else: | ||
| hub_content = HubContent.get( | ||
| hub_content_type="Model", | ||
| hub_name="SageMakerPublicHub", | ||
| hub_content_name=base_model.hub_content_name, | ||
| hub_content_version=( | ||
| base_model.hub_content_version | ||
| ), | ||
| ) | ||
| hub_document = json.loads( | ||
| hub_content.hub_content_document | ||
| ) | ||
| recipe_collection = hub_document.get( | ||
| "RecipeCollection", [] | ||
| ) | ||
| if recipe_collection: | ||
| resolved_recipe = recipe_collection[0].get( | ||
| "Name", "" | ||
| ) | ||
| if resolved_recipe: | ||
|
aviruthen marked this conversation as resolved.
|
||
| base_model.recipe_name = resolved_recipe | ||
| logger.info( | ||
| "Resolved recipe_name to '%s' " | ||
| "for hub content '%s'.", | ||
| resolved_recipe, | ||
| hub_content_name, | ||
| ) | ||
| else: | ||
|
aviruthen marked this conversation as resolved.
|
||
| logger.warning( | ||
| "RecipeCollection found but first " | ||
| "recipe has no Name for hub " | ||
| "content '%s'.", | ||
| hub_content_name, | ||
| ) | ||
| else: | ||
| logger.warning( | ||
| "No RecipeCollection found in hub " | ||
| "document for hub content '%s'. " | ||
| "recipe_name could not be " | ||
| "auto-resolved.", | ||
| hub_content_name, | ||
| ) | ||
| except (ClientError, ValueError) as e: | ||
| logger.warning( | ||
| "Failed to resolve recipe_name " | ||
| "for hub content '%s': %s", | ||
| hub_content_name, | ||
| e, | ||
| ) | ||
|
|
||
| self._base_model_fields_resolved = True | ||
|
|
||
| def _fetch_hub_document_for_custom_model(self) -> dict: | ||
| """Fetch the hub document for a custom (fine-tuned) model. | ||
|
|
||
| Calls _resolve_base_model_fields() first to ensure hub_content_version | ||
| is populated. If hub_content_version is still Unassigned after | ||
| resolution (e.g. resolution failed), HubContent.get() is called | ||
| without a version parameter, which returns the latest version. | ||
| """ | ||
| from sagemaker.core.shapes import BaseModel as CoreBaseModel | ||
|
|
||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The line break in the middle of the attribute chain is unusual and reduces readability:
```python
base_model: CoreBaseModel = (
    self._fetch_model_package()
    .inference_specification.containers[0].base_model
)
```
Consider breaking it more naturally:
```python
model_package = self._fetch_model_package()
base_model: CoreBaseModel = (
    model_package.inference_specification.containers[0].base_model
)
```
This also avoids the long chained call and is consistent with how it's done in the other methods touched by this change. |
||
| self._resolve_base_model_fields() | ||
|
|
||
|
aviruthen marked this conversation as resolved.
|
||
| base_model: CoreBaseModel = ( | ||
| self._fetch_model_package().inference_specification.containers[0].base_model | ||
| self._fetch_model_package() | ||
| .inference_specification.containers[0].base_model | ||
| ) | ||
|
|
||
| hub_content_version = getattr( | ||
| base_model, "hub_content_version", None | ||
| ) | ||
| hub_content = HubContent.get( | ||
| get_kwargs = dict( | ||
| hub_content_type="Model", | ||
| hub_name="SageMakerPublicHub", | ||
| hub_content_name=base_model.hub_content_name, | ||
| hub_content_version=base_model.hub_content_version, | ||
| ) | ||
| if hub_content_version and not isinstance( | ||
| hub_content_version, Unassigned | ||
| ): | ||
| get_kwargs["hub_content_version"] = hub_content_version | ||
|
|
||
| hub_content = HubContent.get(**get_kwargs) | ||
| return json.loads(hub_content.hub_content_document) | ||
|
|
||
| def _fetch_hosting_configs_for_custom_model(self) -> dict: | ||
|
|
@@ -937,9 +1102,29 @@ def _is_gpu_instance(self, instance_type: str) -> bool: | |
|
|
||
| def _fetch_and_cache_recipe_config(self): | ||
| """Fetch and cache image URI, compute requirements, and s3_upload_path from recipe during build.""" | ||
| # _fetch_hub_document_for_custom_model calls _resolve_base_model_fields | ||
| # internally, so no need to call it separately here. | ||
| hub_document = self._fetch_hub_document_for_custom_model() | ||
| model_package = self._fetch_model_package() | ||
| recipe_name = model_package.inference_specification.containers[0].base_model.recipe_name | ||
| base_model = ( | ||
| model_package.inference_specification | ||
| .containers[0].base_model | ||
| ) | ||
| hub_content_name = getattr( | ||
| base_model, "hub_content_name", "unknown" | ||
| ) | ||
| recipe_name = getattr(base_model, "recipe_name", None) | ||
| if not recipe_name or isinstance(recipe_name, Unassigned): | ||
| raise ValueError( | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The error message uses multiple f-string concatenations ( raise ValueError(
f"recipe_name is missing from the model package's BaseModel "
f"(hub_content_name='{hub_content_name}') and could not be "
f"auto-resolved from the hub document. Please ensure the model "
f"package has a valid recipe_name set, or manually set it before "
f"calling build()."
)Also, the example in the error message ( |
||
| f"recipe_name is missing from the model package's " | ||
| f"BaseModel (hub_content_name='{hub_content_name}') " | ||
|
aviruthen marked this conversation as resolved.
|
||
| f"and could not be auto-resolved from the hub " | ||
| f"document. Please ensure the model package has a " | ||
| f"valid recipe_name set, or manually set it before " | ||
| f"calling build(). Example: model_package." | ||
| f"inference_specification.containers[0].base_model" | ||
| f".recipe_name = 'your-recipe-name'" | ||
| ) | ||
|
aviruthen marked this conversation as resolved.
|
||
|
|
||
| if not self.s3_upload_path: | ||
| self.s3_upload_path = model_package.inference_specification.containers[ | ||
|
|
@@ -1060,7 +1245,11 @@ def _is_nova_model(self) -> bool: | |
| if not base_model: | ||
| return False | ||
| recipe_name = getattr(base_model, "recipe_name", "") or "" | ||
| if isinstance(recipe_name, Unassigned): | ||
|
aviruthen marked this conversation as resolved.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good defensive fix for `Unassigned` values. Since the same check is repeated for each field, consider extracting a small helper:
```python
def _normalize_field(value, default=""):
    if not value or isinstance(value, Unassigned):
        return default
    return value
```
|
||
| recipe_name = "" | ||
| hub_content_name = getattr(base_model, "hub_content_name", "") or "" | ||
| if isinstance(hub_content_name, Unassigned): | ||
|
aviruthen marked this conversation as resolved.
|
||
| hub_content_name = "" | ||
| return "nova" in recipe_name.lower() or "nova" in hub_content_name.lower() | ||
|
|
||
| def _is_nova_model_for_telemetry(self) -> bool: | ||
|
|
@@ -1076,8 +1265,12 @@ def _get_nova_hosting_config(self, instance_type=None): | |
| Nova training recipes don't have hosting configs in the JumpStart hub document. | ||
| This provides the hardcoded fallback, matching Rhinestone's getNovaHostingConfigs(). | ||
| """ | ||
| self._resolve_base_model_fields() | ||
|
aviruthen marked this conversation as resolved.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
| model_package = self._fetch_model_package() | ||
| hub_content_name = model_package.inference_specification.containers[0].base_model.hub_content_name | ||
| hub_content_name = ( | ||
| model_package.inference_specification | ||
| .containers[0].base_model.hub_content_name | ||
| ) | ||
|
|
||
| configs = self._NOVA_HOSTING_CONFIGS.get(hub_content_name) | ||
| if not configs: | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.