Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions sagemaker-train/src/sagemaker/train/tuner.py
Comment thread
aviruthen marked this conversation as resolved.
Comment thread
aviruthen marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -1504,7 +1504,13 @@ def _build_training_job_definition(self, inputs):
model_trainer.stopping_condition.max_wait_time_in_seconds
)

Comment thread
aviruthen marked this conversation as resolved.
Comment thread
aviruthen marked this conversation as resolved.
definition = HyperParameterTrainingJobDefinition(
# Propagate environment variables from ModelTrainer.
# Only include when it's a dict (even empty); omit otherwise so the
# Pydantic field stays Unassigned and is excluded during serialization.
env = model_trainer.environment
Comment thread
aviruthen marked this conversation as resolved.

# Build base kwargs for the definition
definition_kwargs = dict(
algorithm_specification=algorithm_spec,
role_arn=model_trainer.role,
input_data_config=input_data_config if input_data_config else None,
Expand All @@ -1515,10 +1521,11 @@ def _build_training_job_definition(self, inputs):
enable_managed_spot_training=model_trainer.compute.enable_managed_spot_training,
)

# Pass through environment variables from model_trainer
env = getattr(model_trainer, "environment", None)
if env and isinstance(env, dict):
definition.environment = env
# Include environment only when it's a dict (including empty).
if isinstance(env, dict):
definition_kwargs["environment"] = env

definition = HyperParameterTrainingJobDefinition(**definition_kwargs)

# Pass through VPC config from model_trainer
networking = getattr(model_trainer, "networking", None)
Expand Down
70 changes: 70 additions & 0 deletions sagemaker-train/tests/unit/train/test_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,3 +596,73 @@ def test_build_training_job_definition_includes_spot_params(self):
assert isinstance(
definition.stopping_condition.max_wait_time_in_seconds, int
Comment thread
aviruthen marked this conversation as resolved.
), "Max wait time should be set"

Comment thread
aviruthen marked this conversation as resolved.
def test_build_training_job_definition_includes_environment_variables(self):
    """Check that ModelTrainer environment variables reach the job definition.

    Regression test for GitHub issue #5613: tuning jobs were missing the
    environment variables that had been set on the ModelTrainer.
    """
    expected_env = {"FOO": "bar", "RANDOM_STATE": "42"}

    trainer = _create_mock_model_trainer()
    # Hand the trainer its own copy so the comparison below is made against
    # an independent dict, as in a literal-vs-literal comparison.
    trainer.environment = dict(expected_env)

    hp_tuner = HyperparameterTuner(
        model_trainer=trainer,
        objective_metric_name="accuracy",
        hyperparameter_ranges=_create_single_hp_range(),
    )
    job_definition = hp_tuner._build_training_job_definition(None)

    assert job_definition.environment is not None, "Environment should not be None"
    assert (
        job_definition.environment == expected_env
    ), "Environment variables should match those set on ModelTrainer"

def test_build_training_job_definition_with_none_environment(self):
    """Check that a ``None`` environment leaves the definition field Unassigned.

    The test sets ``environment`` to ``None`` on the mock trainer and expects
    the built definition's ``environment`` to be an ``Unassigned`` instance,
    i.e. the value was never handed to the definition.
    """
    from sagemaker.core.utils.utils import Unassigned

    trainer = _create_mock_model_trainer()
    trainer.environment = None

    hp_tuner = HyperparameterTuner(
        model_trainer=trainer,
        objective_metric_name="accuracy",
        hyperparameter_ranges=_create_single_hp_range(),
    )
    job_definition = hp_tuner._build_training_job_definition(None)

    env_value = job_definition.environment
    assert isinstance(
        env_value, Unassigned
    ), "Environment should be Unassigned when model_trainer.environment is None"

def test_build_training_job_definition_with_empty_environment(self):
    """Check that an empty environment dict is forwarded unchanged.

    The expectation encoded here is that ``{}`` is passed through as-is
    rather than being silently converted to ``None``.
    """
    trainer = _create_mock_model_trainer()
    trainer.environment = {}

    hp_tuner = HyperparameterTuner(
        model_trainer=trainer,
        objective_metric_name="accuracy",
        hyperparameter_ranges=_create_single_hp_range(),
    )
    job_definition = hp_tuner._build_training_job_definition(None)

    assert (
        job_definition.environment == {}
    ), "Empty dict environment should be passed through as-is"
39 changes: 35 additions & 4 deletions sagemaker-train/tests/unit/train/test_tuner_driver_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,8 +405,31 @@ def test_passes_environment_variables(self):
definition = tuner._build_training_job_definition(inputs=None)
assert definition.environment == {"MY_VAR": "value", "OTHER": "123"}

def test_passes_empty_environment(self):
    """Should forward an empty dict environment untouched.

    The expectation encoded here is that ``{}`` is passed through rather
    than being silently converted to None/Unassigned.
    """
    model_trainer = _mock_model_trainer(environment={})
    hp_tuner = HyperparameterTuner(
        model_trainer=model_trainer,
        objective_metric_name="accuracy",
        hyperparameter_ranges=_hp_ranges(),
    )

    built = hp_tuner._build_training_job_definition(inputs=None)

    assert (
        built.environment == {}
    ), "Empty dict environment should be passed through as-is"

def test_skips_environment_when_none(self):
"""Should not set environment when model_trainer.environment is None."""
"""Should not set environment when model_trainer.environment is None.

When environment is None, it is not passed to the Pydantic constructor,
so the field stays as Unassigned (excluded from serialization).
"""
trainer = _mock_model_trainer(environment=None)

tuner = HyperparameterTuner(
Expand All @@ -416,10 +439,16 @@ def test_skips_environment_when_none(self):
)

definition = tuner._build_training_job_definition(inputs=None)
assert _is_unassigned(definition.environment)
assert _is_unassigned(definition.environment), (
"Environment should be Unassigned when model_trainer.environment is None"
)

def test_skips_environment_when_not_dict(self):
"""Should not set environment when it's not a dict (e.g. MagicMock)."""
"""Should not set environment when it's not a dict (e.g. MagicMock).

Non-dict values are not passed to the Pydantic constructor to avoid
validation errors. The field stays as Unassigned.
"""
trainer = _mock_model_trainer(environment=MagicMock())

tuner = HyperparameterTuner(
Expand All @@ -429,7 +458,9 @@ def test_skips_environment_when_not_dict(self):
)

definition = tuner._build_training_job_definition(inputs=None)
assert _is_unassigned(definition.environment)
assert _is_unassigned(definition.environment), (
"Environment should be Unassigned when model_trainer.environment is not a dict"
)

def test_passes_vpc_config(self):
"""Should set definition.vpc_config from model_trainer.networking._to_vpc_config()."""
Expand Down
Loading