Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 15 additions & 12 deletions sagemaker-core/src/sagemaker/core/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,22 +360,17 @@ def __init__(
logger.debug("No config provided. Using default config.")
config = Config(retries={"max_attempts": 10, "mode": "standard"})

self.config = Config(user_agent_extra=get_user_agent_extra_suffix())
# Merge the provided config with user_agent_extra suffix
self.config = config.merge(Config(user_agent_extra=get_user_agent_extra_suffix()))
self.session = session
self.region_name = region_name
# Read region from environment variable, default to us-west-2
import os
env_region = os.environ.get('SAGEMAKER_REGION', region_name)
env_stage = os.environ.get('SAGEMAKER_STAGE', 'prod') # default to gamma
logger.info(f"Runs on sagemaker {env_stage}, region:{env_region}")


self.sagemaker_client = session.client(
"sagemaker",
region_name=env_region,
region_name=region_name,
config=self.config,
)

self.sagemaker_runtime_client = session.client(
"sagemaker-runtime", region_name, config=self.config
)
Expand Down Expand Up @@ -538,7 +533,9 @@ def _serialize_dict(value: Dict) -> dict:
dict: The serialized dict
"""
serialized_dict = {}
# Drop only Unassigned/None; preserve valid falsy values like False, 0, "".
# Drop only Unassigned/None; preserve valid falsy values like False, 0, "", [], {}.
# Note: empty containers ([] and {}) are now preserved where previously they were
# filtered as falsy. This is intentional to correctly handle all non-None values.
for k, v in value.items():
serialize_result = serialize(v)
if serialize_result is not None:
Expand All @@ -557,7 +554,9 @@ def _serialize_list(value: List) -> list:
list: The serialized list
"""
serialized_list = []
# Drop only Unassigned/None; preserve valid falsy values like False, 0, "".
# Drop only Unassigned/None; preserve valid falsy values like False, 0, "", [], {}.
# Note: empty containers ([] and {}) are now preserved where previously they were
# filtered as falsy. This is intentional to correctly handle all non-None values.
for v in value:
serialize_result = serialize(v)
if serialize_result is not None:
Expand All @@ -576,8 +575,12 @@ def _serialize_shape(value: Any) -> dict:
dict: The dict of serialized shape
"""
serialized_dict = {}
# Drop only Unassigned/None; preserve valid falsy values like False, 0, "", [], {}.
# Note: empty containers ([] and {}) are now preserved where previously they were
# filtered as falsy. This is intentional to correctly handle all non-None values.
for k, v in vars(value).items():
if serialize_result := serialize(v):
serialize_result = serialize(v)
if serialize_result is not None:
key = snake_to_pascal(k) if is_snake_case(k) else k
serialized_dict.update({key[0].upper() + key[1:]: serialize_result})
return serialized_dict
34 changes: 31 additions & 3 deletions sagemaker-core/tests/unit/generated/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,18 @@
TrialComponent,
TrialComponentParameterValue,
)
from sagemaker.core.utils.utils import *
from sagemaker.core.utils.utils import (
configure_logging,
is_snake_case,
snake_to_pascal,
pascal_to_snake,
is_not_primitive,
is_primitive_class,
serialize,
ResourceIterator,
SageMakerClient,
Unassigned,
)


LIST_TRAINING_JOB_RESPONSE_WITH_NEXT_TOKEN = {
Expand Down Expand Up @@ -266,7 +277,6 @@ def test_next_with_custom_key_mapping(resource_iterator_with_custom_key_mapping)
try:
next_item = next(iterator)
assert isinstance(next_item, DataQualityJobDefinition)
print(next_item)
expected_data_quality_job_definition_data = (
LIST_DATA_QUALITY_JOB_DEFINITION_RESPONSE_WITHOUT_NEXT_TOKEN[
"JobDefinitionSummaries"
Expand Down Expand Up @@ -303,7 +313,6 @@ def test_next_with_primitive_class(resource_iterator_with_primitive_class):
try:
next_item = next(iterator)
assert isinstance(next_item, str)
print(next_item)
expected_image_version_alias_data = LIST_ALIASES_RESPONSE_WITHOUT_NEXT_TOKEN[
"SageMakerImageVersionAliases"
][index]
Expand Down Expand Up @@ -387,6 +396,25 @@ def test_serialize_method_nested_shape():
}


def test_serialize_shape_preserves_falsy_attribute_values():
"""Regression: _serialize_shape should preserve falsy values like empty string, 0, False."""
# TrialComponentParameterValue with number_value=0 should preserve the 0
param_zero = TrialComponentParameterValue(number_value=0)
serialized = serialize(param_zero)
assert serialized["NumberValue"] == 0

# TrialComponentParameterValue with string_value="" should preserve the empty string
param_empty_str = TrialComponentParameterValue(string_value="")
serialized = serialize(param_empty_str)
assert serialized["StringValue"] == ""

# Test that Unassigned values are still excluded
param_unassigned = TrialComponentParameterValue()
serialized = serialize(param_unassigned)
assert "NumberValue" not in serialized
assert "StringValue" not in serialized


class TestUnassignedBehavior:
"""Test Unassigned class methods for proper behavior.

Expand Down
Loading