diff --git a/sagemaker-core/src/sagemaker/core/spark/processing.py b/sagemaker-core/src/sagemaker/core/spark/processing.py index 82cdef954c..e49a9764c4 100644 --- a/sagemaker-core/src/sagemaker/core/spark/processing.py +++ b/sagemaker-core/src/sagemaker/core/spark/processing.py @@ -37,7 +37,12 @@ from sagemaker.core import image_uris from sagemaker.core import s3 from sagemaker.core.local.image import _ecr_login_if_needed, _pull_image -from sagemaker.core.processing import ProcessingInput, ProcessingOutput, ScriptProcessor +from sagemaker.core.processing import ( + ProcessingInput, + ProcessingOutput, + ProcessingS3Input, + ScriptProcessor, +) from sagemaker.core.s3 import S3Uploader from sagemaker.core.helper.session_helper import Session from sagemaker.core.network import NetworkConfig @@ -52,6 +57,18 @@ logger = logging.getLogger(__name__) +def _make_s3_processing_input(input_name: str, s3_uri: str, local_path: str) -> ProcessingInput: + """Build a V3-compatible ProcessingInput backed by an S3 channel.""" + return ProcessingInput( + input_name=input_name, + s3_input=ProcessingS3Input( + s3_uri=s3_uri, + s3_data_type="S3Prefix", + local_path=local_path, + ), + ) + + class _SparkProcessorBase(ScriptProcessor): """Handles Amazon SageMaker processing tasks for jobs using Spark. @@ -404,10 +421,10 @@ def _stage_configuration(self, configuration): sagemaker_session=self.sagemaker_session, ) - conf_input = ProcessingInput( - source=s3_uri, - destination=f"{self._conf_container_base_path}{self._conf_container_input_name}", - input_name=_SparkProcessorBase._conf_container_input_name, + conf_input = _make_s3_processing_input( + input_name=self._conf_container_input_name, + s3_uri=s3_uri, + local_path=f"{self._conf_container_base_path}{self._conf_container_input_name}", ) return conf_input @@ -505,15 +522,16 @@ def _stage_submit_deps(self, submit_deps, input_channel_name): # them to the Spark container and form the spark-submit option from a # combination of S3 URIs and container's local input path if use_input_channel: - input_channel = ProcessingInput( - source=input_channel_s3_uri, - destination=f"{self._conf_container_base_path}{input_channel_name}", + local_path = f"{self._conf_container_base_path}{input_channel_name}" + input_channel = _make_s3_processing_input( input_name=input_channel_name, + s3_uri=input_channel_s3_uri, + local_path=local_path, ) spark_opt = ( - Join(on=",", values=spark_opt_s3_uris + [input_channel.destination]) + Join(on=",", values=spark_opt_s3_uris + [local_path]) if spark_opt_s3_uris_has_pipeline_var - else ",".join(spark_opt_s3_uris + [input_channel.destination]) + else ",".join(spark_opt_s3_uris + [local_path]) ) # If no local files were uploaded, form the spark-submit option from a list of S3 URIs else: diff --git a/sagemaker-core/tests/unit/spark/test_processing.py b/sagemaker-core/tests/unit/spark/test_processing.py new file mode 100644 index 0000000000..dd9f692acf --- /dev/null +++ b/sagemaker-core/tests/unit/spark/test_processing.py @@ -0,0 +1,79 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +from unittest.mock import Mock, patch + +import pytest + +from sagemaker.core.spark.processing import PySparkProcessor + + +@pytest.fixture +def mock_session(): + session = Mock() + session.boto_session = Mock() + session.boto_session.region_name = "us-west-2" + session.boto_region_name = "us-west-2" + session.sagemaker_client = Mock() + session.default_bucket = Mock(return_value="test-bucket") + session.default_bucket_prefix = "sagemaker" + session.expand_role = Mock(side_effect=lambda x: x) + session.sagemaker_config = {} + return session + + +def _make_processor(mock_session): + processor = PySparkProcessor( + role="arn:aws:iam::123456789012:role/SageMakerRole", + image_uri="test-image:latest", + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + processor._current_job_name = "test-job" + return processor + + +class TestPySparkProcessorV3ProcessingInputs: + @patch("sagemaker.core.spark.processing.S3Uploader.upload_string_as_file_body") + def test_stage_configuration_builds_v3_processing_input(self, mock_upload, mock_session): + processor = _make_processor(mock_session) + + config_input = processor._stage_configuration( + [{"Classification": "spark-defaults", "Properties": {"spark.app.name": "test"}}] + ) + + mock_upload.assert_called_once() + assert config_input.input_name == processor._conf_container_input_name + assert config_input.s3_input.s3_uri == ( + "s3://test-bucket/sagemaker/test-job/input/conf/configuration.json" + ) + assert config_input.s3_input.local_path == "/opt/ml/processing/input/conf" + + @patch("sagemaker.core.spark.processing.S3Uploader.upload") + def test_stage_submit_deps_builds_v3_processing_input_for_local_dependencies( + self, mock_upload, mock_session, tmp_path + ): + processor = _make_processor(mock_session) + dep_file = tmp_path / "dep.py" + dep_file.write_text("print('dep')", encoding="utf-8") + + input_channel, spark_opt = processor._stage_submit_deps( + [str(dep_file)], processor._submit_py_files_input_channel_name + ) + + mock_upload.assert_called_once() + assert input_channel.input_name == processor._submit_py_files_input_channel_name + assert input_channel.s3_input.s3_uri == "s3://test-bucket/sagemaker/test-job/input/py-files" + assert input_channel.s3_input.local_path == "/opt/ml/processing/input/py-files" + assert spark_opt == "/opt/ml/processing/input/py-files"