Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 28 additions & 10 deletions sagemaker-core/src/sagemaker/core/spark/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,12 @@
from sagemaker.core import image_uris
from sagemaker.core import s3
from sagemaker.core.local.image import _ecr_login_if_needed, _pull_image
from sagemaker.core.processing import ProcessingInput, ProcessingOutput, ScriptProcessor
from sagemaker.core.processing import (
ProcessingInput,
ProcessingOutput,
ProcessingS3Input,
ScriptProcessor,
)
from sagemaker.core.s3 import S3Uploader
from sagemaker.core.helper.session_helper import Session
from sagemaker.core.network import NetworkConfig
Expand All @@ -52,6 +57,18 @@
logger = logging.getLogger(__name__)


def _make_s3_processing_input(input_name: str, s3_uri: str, local_path: str) -> ProcessingInput:
    """Create a ProcessingInput wrapping an S3Prefix-typed channel (V3-compatible shape).

    Args:
        input_name: Name of the processing input channel.
        s3_uri: S3 URI the channel reads from.
        local_path: Container-local path the data is mounted at.

    Returns:
        ProcessingInput: The input with its ``s3_input`` channel populated.
    """
    # Build the S3 channel first, then wrap it — identical to constructing inline.
    s3_channel = ProcessingS3Input(
        s3_uri=s3_uri,
        s3_data_type="S3Prefix",
        local_path=local_path,
    )
    return ProcessingInput(input_name=input_name, s3_input=s3_channel)


class _SparkProcessorBase(ScriptProcessor):
"""Handles Amazon SageMaker processing tasks for jobs using Spark.

Expand Down Expand Up @@ -404,10 +421,10 @@ def _stage_configuration(self, configuration):
sagemaker_session=self.sagemaker_session,
)

conf_input = ProcessingInput(
source=s3_uri,
destination=f"{self._conf_container_base_path}{self._conf_container_input_name}",
input_name=_SparkProcessorBase._conf_container_input_name,
conf_input = _make_s3_processing_input(
input_name=self._conf_container_input_name,
s3_uri=s3_uri,
local_path=f"{self._conf_container_base_path}{self._conf_container_input_name}",
)
return conf_input

Expand Down Expand Up @@ -505,15 +522,16 @@ def _stage_submit_deps(self, submit_deps, input_channel_name):
# them to the Spark container and form the spark-submit option from a
# combination of S3 URIs and container's local input path
if use_input_channel:
input_channel = ProcessingInput(
source=input_channel_s3_uri,
destination=f"{self._conf_container_base_path}{input_channel_name}",
local_path = f"{self._conf_container_base_path}{input_channel_name}"
input_channel = _make_s3_processing_input(
input_name=input_channel_name,
s3_uri=input_channel_s3_uri,
local_path=local_path,
)
spark_opt = (
Join(on=",", values=spark_opt_s3_uris + [input_channel.destination])
Join(on=",", values=spark_opt_s3_uris + [local_path])
if spark_opt_s3_uris_has_pipeline_var
else ",".join(spark_opt_s3_uris + [input_channel.destination])
else ",".join(spark_opt_s3_uris + [local_path])
)
# If no local files were uploaded, form the spark-submit option from a list of S3 URIs
else:
Expand Down
79 changes: 79 additions & 0 deletions sagemaker-core/tests/unit/spark/test_processing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

from unittest.mock import Mock, patch

import pytest

from sagemaker.core.spark.processing import PySparkProcessor


@pytest.fixture
def mock_session():
    """Return a mocked SageMaker session preconfigured for us-west-2 and a test bucket."""
    region = "us-west-2"
    session = Mock()
    session.boto_session = Mock()
    session.boto_session.region_name = region
    session.boto_region_name = region
    session.sagemaker_client = Mock()
    session.default_bucket = Mock(return_value="test-bucket")
    session.default_bucket_prefix = "sagemaker"
    # expand_role is an identity pass-through: the role ARN is returned unchanged.
    session.expand_role = Mock(side_effect=lambda role: role)
    session.sagemaker_config = {}
    return session


def _make_processor(mock_session):
    """Build a PySparkProcessor bound to the mocked session, with a fixed job name."""
    proc = PySparkProcessor(
        role="arn:aws:iam::123456789012:role/SageMakerRole",
        image_uri="test-image:latest",
        instance_count=1,
        instance_type="ml.m5.xlarge",
        sagemaker_session=mock_session,
    )
    # Pin the job name so staged S3 URIs are deterministic in assertions.
    proc._current_job_name = "test-job"
    return proc


class TestPySparkProcessorV3ProcessingInputs:
    """Checks that staged Spark inputs are emitted as V3-style ProcessingInput objects."""

    @patch("sagemaker.core.spark.processing.S3Uploader.upload_string_as_file_body")
    def test_stage_configuration_builds_v3_processing_input(self, mock_upload, mock_session):
        processor = _make_processor(mock_session)
        spark_conf = [{"Classification": "spark-defaults", "Properties": {"spark.app.name": "test"}}]

        config_input = processor._stage_configuration(spark_conf)

        expected_uri = "s3://test-bucket/sagemaker/test-job/input/conf/configuration.json"
        mock_upload.assert_called_once()
        assert config_input.input_name == processor._conf_container_input_name
        assert config_input.s3_input.s3_uri == expected_uri
        assert config_input.s3_input.local_path == "/opt/ml/processing/input/conf"

    @patch("sagemaker.core.spark.processing.S3Uploader.upload")
    def test_stage_submit_deps_builds_v3_processing_input_for_local_dependencies(
        self, mock_upload, mock_session, tmp_path
    ):
        processor = _make_processor(mock_session)
        dependency = tmp_path / "dep.py"
        dependency.write_text("print('dep')", encoding="utf-8")

        channel, spark_opt = processor._stage_submit_deps(
            [str(dependency)], processor._submit_py_files_input_channel_name
        )

        expected_uri = "s3://test-bucket/sagemaker/test-job/input/py-files"
        expected_local = "/opt/ml/processing/input/py-files"
        mock_upload.assert_called_once()
        assert channel.input_name == processor._submit_py_files_input_channel_name
        assert channel.s3_input.s3_uri == expected_uri
        assert channel.s3_input.local_path == expected_local
        assert spark_opt == expected_local
Loading