diff --git a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py index 6479e803bd..dd91c347f4 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py @@ -743,7 +743,7 @@ def _create_output_config(sagemaker_session,s3_output_path=None, kms_key_id=None s3_output_path = _get_default_s3_output_path(sagemaker_session) # Validate S3 path exists - _validate_s3_path_exists(s3_output_path, sagemaker_session) + _validate_s3_path_exists(s3_output_path, sagemaker_session, kms_key_id=kms_key_id) return OutputDataConfig( s3_output_path=s3_output_path, @@ -830,8 +830,16 @@ def _validate_eula_for_gated_model(model, accept_eula, is_gated_model): return accept_eula -def _validate_s3_path_exists(s3_path: str, sagemaker_session): - """Validate S3 path and create bucket/prefix if they don't exist.""" +def _validate_s3_path_exists(s3_path: str, sagemaker_session, kms_key_id: Optional[str] = None): + """Validate S3 path and create bucket/prefix if they don't exist. + + Args: + s3_path: S3 path to validate. + sagemaker_session: SageMaker session used to create the S3 client. + kms_key_id: Optional KMS key ID. When provided, objects created to + initialize a missing prefix are written with SSE-KMS encryption so + that validation does not fail on buckets enforcing KMS encryption. + """ if not s3_path.startswith("s3://"): raise ValueError(f"Invalid S3 path format: {s3_path}") @@ -867,7 +875,11 @@ def _validate_s3_path_exists(s3_path: str, sagemaker_session): # Create the prefix by putting an empty object if not prefix.endswith('/'): prefix += '/' - s3_client.put_object(Bucket=bucket_name, Key=prefix, Body=b'') + put_object_kwargs = {"Bucket": bucket_name, "Key": prefix, "Body": b''} + if kms_key_id: + put_object_kwargs["ServerSideEncryption"] = "aws:kms" + put_object_kwargs["SSEKMSKeyId"] = kms_key_id + s3_client.put_object(**put_object_kwargs) except Exception as e: raise ValueError(f"Failed to validate/create S3 path '{s3_path}': {str(e)}") diff --git a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py index c98dea477f..cef9c62eee 100644 --- a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py +++ b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py @@ -452,7 +452,9 @@ def test__create_output_config(self, mock_validate_s3): assert config.s3_output_path == "s3://bucket/output" assert config.kms_key_id == "kms-key" - mock_validate_s3.assert_called_once_with("s3://bucket/output", mock_session) + mock_validate_s3.assert_called_once_with( + "s3://bucket/output", mock_session, kms_key_id="kms-key" + ) def test__convert_input_data_to_channels(self): @@ -558,6 +560,24 @@ def test__validate_s3_path_exists_with_prefix_not_exists(self, mock_boto_client) mock_s3_client.list_objects_v2.assert_called_once_with(Bucket="test-bucket", Prefix="prefix", MaxKeys=1) mock_s3_client.put_object.assert_called_once_with(Bucket="test-bucket", Key="prefix/", Body=b'') + @patch('boto3.client') + def test__validate_s3_path_exists_with_prefix_not_exists_kms(self, mock_boto_client): + """Test S3 path validation uses SSE-KMS when kms_key_id is provided""" + mock_session = Mock() + mock_s3_client = Mock() + mock_session.boto_session.client.return_value = mock_s3_client + mock_s3_client.list_objects_v2.return_value = {} # No contents + + _validate_s3_path_exists("s3://test-bucket/prefix", mock_session, kms_key_id="kms-key") + + mock_s3_client.put_object.assert_called_once_with( + Bucket="test-bucket", + Key="prefix/", + Body=b'', + ServerSideEncryption="aws:kms", + SSEKMSKeyId="kms-key", + ) + class TestMlflowVersionMeetsMinimum: def test_meets_minimum(self):