From cc1730a2ccbbfc193efde0c773d2b36916cf2827 Mon Sep 17 00:00:00 2001 From: SageMaker Bot <49924207+sagemaker-bot@users.noreply.github.com> Date: Mon, 1 Jun 2026 00:28:42 -0700 Subject: [PATCH] fix: (5873) --- .../src/sagemaker/core/processing.py | 346 ++++++++++++++---- 1 file changed, 283 insertions(+), 63 deletions(-) diff --git a/sagemaker-core/src/sagemaker/core/processing.py b/sagemaker-core/src/sagemaker/core/processing.py index 62493719dc..37e35184d8 100644 --- a/sagemaker-core/src/sagemaker/core/processing.py +++ b/sagemaker-core/src/sagemaker/core/processing.py @@ -25,10 +25,10 @@ import re from typing import Dict, List, Optional, Union import time -from copy import copy +from copy import deepcopy from textwrap import dedent -from six.moves.urllib.parse import urlparse -from six.moves.urllib.request import url2pathname +from urllib.parse import urlparse +from urllib.request import url2pathname from sagemaker.core.network import NetworkConfig from sagemaker.core import s3 from sagemaker.core.apiutils._base_types import ApiObject @@ -94,7 +94,7 @@ class Processor(object): def __init__( self, - role: str = None, + role: Optional[Union[str, PipelineVariable]] = None, image_uri: Union[str, PipelineVariable] = None, instance_count: Union[int, PipelineVariable] = None, instance_type: Union[str, PipelineVariable] = None, @@ -420,12 +420,29 @@ def _normalize_inputs(self, inputs=None, kms_key=None): if file_input.dataset_definition: normalized_inputs.append(file_input) continue - if file_input.s3_input and is_pipeline_variable(file_input.s3_input.s3_uri): + + # Guard against s3_input being None + if file_input.s3_input is None: + normalized_inputs.append(file_input) + continue + + # Guard against s3_uri being None + if file_input.s3_input.s3_uri is None: + normalized_inputs.append(file_input) + continue + + if is_pipeline_variable(file_input.s3_input.s3_uri): normalized_inputs.append(file_input) continue # If the s3_uri is not an s3_uri, create one. parse_result = urlparse(file_input.s3_input.s3_uri) if parse_result.scheme != "s3": + is_local_mode = getattr( + self.sagemaker_session, "local_mode", False + ) + if is_local_mode and parse_result.scheme == "file": + normalized_inputs.append(file_input) + continue if _pipeline_config: desired_s3_uri = s3.s3_path_join( "s3://", @@ -483,13 +500,27 @@ def _normalize_outputs(self, outputs=None): # Generate a name for the ProcessingOutput if it doesn't have one. if output.output_name is None: output.output_name = "output-{}".format(count) - if output.s3_output and is_pipeline_variable(output.s3_output.s3_uri): + + # Guard against s3_output being None + if output.s3_output is None: + normalized_outputs.append(output) + continue + + # Guard against s3_uri being None + if output.s3_output.s3_uri is None: + normalized_outputs.append(output) + continue + + if is_pipeline_variable(output.s3_output.s3_uri): normalized_outputs.append(output) continue # If the output's s3_uri is not an s3_uri, create one. parse_result = urlparse(output.s3_output.s3_uri) if parse_result.scheme != "s3": - if getattr(self.sagemaker_session, "local_mode", False) and parse_result.scheme == "file": + is_local_mode = getattr( + self.sagemaker_session, "local_mode", False + ) + if is_local_mode and parse_result.scheme == "file": normalized_outputs.append(output) continue if _pipeline_config: @@ -525,7 +556,17 @@ def _normalize_outputs(self, outputs=None): return normalized_outputs def _start_new(self, inputs, outputs, experiment_config): - """Starts a new processing job and returns ProcessingJob instance.""" + """Starts a new processing job and returns ProcessingJob instance. + + For PipelineSession, this method returns None since the job is not immediately + created but instead registered as a pipeline step. In this case, + ``self.latest_job`` will be set to None by the caller. Callers should + check ``self.latest_job`` for None before using it. + + Returns: + ProcessingJob: The created processing job instance, or None if using + PipelineSession (job creation is deferred to pipeline execution). + """ from sagemaker.core.workflow.pipeline_context import PipelineSession process_args = self._get_process_args(inputs, outputs, experiment_config) @@ -606,14 +647,27 @@ def _start_new(self, inputs, outputs, experiment_config): serialized_request = serialize(process_request) if isinstance(self.sagemaker_session, PipelineSession): - self.sagemaker_session._intercept_create_request(serialized_request, None, "process") - return + # For PipelineSession, the job is not created immediately. The request is + # intercepted and stored for later pipeline execution. No submit callback + # is needed since the pipeline orchestrator handles job creation. + # Returns None intentionally - self.latest_job will be None for pipeline sessions. + self.sagemaker_session._intercept_create_request( + serialized_request, None, "process" + ) + return None def submit(request): try: - logger.info("Creating processing-job with name %s", process_args["job_name"]) - logger.debug("process request: %s", json.dumps(request, indent=4)) - self.sagemaker_session.sagemaker_client.create_processing_job(**request) + logger.info( + "Creating processing-job with name %s", + process_args["job_name"], + ) + logger.debug( + "process request: %s", json.dumps(request, indent=4) + ) + self.sagemaker_session.sagemaker_client.create_processing_job( + **request + ) except Exception as e: troubleshooting = ( "https://docs.aws.amazon.com/sagemaker/latest/dg/" @@ -621,11 +675,21 @@ def submit(request): "#sagemaker-python-sdk-troubleshooting-create-processing-job" ) logger.error( - "Please check the troubleshooting guide for common errors: %s", troubleshooting + "Failed to create processing job '%s'. " + "Please check the troubleshooting guide for common errors: %s", + process_args["job_name"], + troubleshooting, ) raise e - self.sagemaker_session._intercept_create_request(serialized_request, submit, "process") + # Contract: _intercept_create_request is guaranteed to invoke the submit callback + # for non-PipelineSession sessions. It may modify serialized_request in-place + # (e.g., to add request metadata) before calling submit. The same dict reference + # is used below to construct the ProcessingJob, so any in-place modifications + # by the interceptor will be reflected in the resulting ProcessingJob object. + self.sagemaker_session._intercept_create_request( + serialized_request, submit, "process" + ) from sagemaker.core.utils.code_injection.codec import transform @@ -1123,11 +1187,14 @@ def _s3_code_prefix(self): """Return the S3 prefix for code uploads, respecting code_location if set.""" if self.code_location: return self.code_location - return s3.s3_path_join( + prefix = self.sagemaker_session.default_bucket_prefix + parts = [ "s3://", self.sagemaker_session.default_bucket(), - self.sagemaker_session.default_bucket_prefix or "", - ) + ] + if prefix: + parts.append(prefix) + return s3.s3_path_join(*parts) def _package_code( self, @@ -1137,31 +1204,114 @@ def _package_code( job_name, kms_key, ): - """Package and upload code to S3.""" + """Package and upload code to S3. + + Args: + entry_point (str): Path to the entry point script. Must be an existing file. + source_dir (str): Path to the source directory containing code dependencies. + If None, the directory containing entry_point is used. + + .. warning:: + When ``source_dir`` is None and ``entry_point`` is a relative path, + the entire directory containing the entry point will be packaged. + This may include unintended files such as credentials, configuration + files, or other sensitive data. It is **strongly recommended** to + explicitly set ``source_dir`` to a directory containing only the + files needed for the processing job. + + requirements (str): Path to a requirements.txt file relative to source_dir. + If provided, the requirements file will be included in the packaged tar. + job_name (str): The processing job name, used for S3 path construction. + kms_key (str): The ARN of the KMS key for encrypting the uploaded code. + + Returns: + str: The S3 URI of the uploaded sourcedir.tar.gz. + + Raises: + ValueError: If entry_point does not exist or is not a file, or if + source_dir does not exist. + """ import tarfile import tempfile - # If source_dir is not provided, use the directory containing entry_point - if source_dir is None: - if os.path.isabs(entry_point): - source_dir = os.path.dirname(entry_point) - else: - source_dir = os.path.dirname(os.path.abspath(entry_point)) - - # Resolve source_dir to absolute path - if not os.path.isabs(source_dir): - source_dir = os.path.abspath(source_dir) + # If source_dir is explicitly provided, validate it first + if source_dir is not None: + if not os.path.isabs(source_dir): + source_dir = os.path.abspath(source_dir) + if not os.path.exists(source_dir): + raise ValueError(f"source_dir does not exist: {source_dir}") + + # Validate entry_point exists and is a file + entry_point_path = ( + os.path.abspath(entry_point) + if not os.path.isabs(entry_point) + else entry_point + ) + if not os.path.exists(entry_point_path): + raise ValueError( + f"entry_point does not exist: {entry_point}. " + "Please provide a valid path to an existing file." + ) + if not os.path.isfile(entry_point_path): + raise ValueError( + f"entry_point is not a file: {entry_point}. " + "Please provide a path to a file, not a directory." + ) - if not os.path.exists(source_dir): - raise ValueError(f"source_dir does not exist: {source_dir}") + # If source_dir is not provided, use the directory containing entry_point. + # Warn the user since this may package unintended files (including credentials). + if source_dir is None: + source_dir = os.path.dirname(entry_point_path) + logger.warning( + "source_dir is None, defaulting to the directory containing entry_point: '%s'. " + "This may package unintended files including credentials or sensitive data. " + "It is strongly recommended to explicitly set source_dir to a directory " + "containing only the files needed for the processing job.", + source_dir, + ) - # Create tar.gz with source_dir contents - with tempfile.NamedTemporaryFile(suffix=".tar.gz", delete=False) as tmp: - with tarfile.open(tmp.name, "w:gz") as tar: + # Create tar.gz with source_dir contents. + # We use delete=False and manage cleanup in a try/finally block. + # The file handle is closed inside the `with` block before we use the + # path with tarfile, which is safe on all platforms. + tmp_name = None + try: + with tempfile.NamedTemporaryFile( + suffix=".tar.gz", delete=False + ) as tmp: + tmp_name = tmp.name + + # Track all files added (using their arcnames) to avoid duplicates. + # We collect arcnames recursively so subdirectory contents are tracked. + added_arcnames = set() + + with tarfile.open(tmp_name, "w:gz") as tar: # Add all files from source_dir to the root of the tar for item in os.listdir(source_dir): item_path = os.path.join(source_dir, item) tar.add(item_path, arcname=item) + # Track this item and all nested paths for duplicate detection + if os.path.isdir(item_path): + for root, dirs, files in os.walk(item_path): + rel_root = os.path.relpath(root, source_dir) + added_arcnames.add(rel_root) + for f in files: + added_arcnames.add( + os.path.join(rel_root, f) + ) + else: + added_arcnames.add(item) + + # Include requirements file if specified and not already included + # via source_dir listing (e.g., if requirements points to a file + # outside source_dir or in a subdirectory with a different arcname) + if requirements: + req_path = os.path.join(source_dir, requirements) + if ( + os.path.isfile(req_path) + and requirements not in added_arcnames + ): + tar.add(req_path, arcname=requirements) # Upload to S3 s3_uri = s3.s3_path_join( @@ -1172,15 +1322,17 @@ def _package_code( ) # Upload the tar file directly to S3 - s3.S3Uploader.upload_string_as_file_body( - body=open(tmp.name, "rb").read(), + s3.S3Uploader.upload( + local_path=tmp_name, desired_s3_uri=s3_uri, kms_key=kms_key, sagemaker_session=self.sagemaker_session, ) + finally: + if tmp_name and os.path.exists(tmp_name): + os.unlink(tmp_name) - os.unlink(tmp.name) - return s3_uri + return s3_uri @_telemetry_emitter(feature=Feature.PROCESSING, func_name="FrameworkProcessor.run") @runnable_by_pipeline @@ -1238,18 +1390,31 @@ def run( ) # Submit a processing job. - return super().run( - code=s3_runproc_sh, + # Note: We call ScriptProcessor.run directly (bypassing its decorators) to avoid + # double telemetry/pipeline interception since this method is already decorated. + normalized_inputs, normalized_outputs = self._normalize_args( + job_name=job_name, + arguments=arguments, inputs=inputs, outputs=outputs, - arguments=arguments, - wait=wait, - logs=logs, - job_name=job_name, - experiment_config=experiment_config, + code=s3_runproc_sh, kms_key=kms_key, ) + experiment_config = check_and_get_run_experiment_config(experiment_config) + self.latest_job = self._start_new( + inputs=normalized_inputs, + outputs=normalized_outputs, + experiment_config=experiment_config, + ) + + from sagemaker.core.workflow.pipeline_context import PipelineSession + + if not isinstance(self.sagemaker_session, PipelineSession): + self.jobs.append(self.latest_job) + if wait: + self.latest_job.wait(logs=logs) + def _pack_and_upload_code( self, code, @@ -1279,13 +1444,16 @@ def _pack_and_upload_code( entrypoint_s3_uri = s3_payload.replace("sourcedir.tar.gz", "runproc.sh") - # Upload the CodeArtifact-aware install_requirements script alongside the source code - import sagemaker.core.utils.install_requirements as _ir_mod - - install_req_s3_uri = s3_payload.replace("sourcedir.tar.gz", "install_requirements.py") + # Upload the CodeArtifact-aware install_requirements script alongside the source code. + install_req_s3_uri = s3_payload.replace( + "sourcedir.tar.gz", "install_requirements.py" + ) evaluated_kms_key = kms_key if kms_key else self.output_kms_key + + install_req_content = self._load_install_requirements_content() + s3.S3Uploader.upload_string_as_file_body( - body=open(_ir_mod.__file__, "r").read(), + body=install_req_content, desired_s3_uri=install_req_s3_uri, kms_key=evaluated_kms_key, sagemaker_session=self.sagemaker_session, @@ -1299,15 +1467,70 @@ def _pack_and_upload_code( return s3_runproc_sh, inputs, job_name + @staticmethod + def _load_install_requirements_content(): + """Load the install_requirements.py module content. + + Uses pkgutil.get_data for robustness when the package may be distributed + as a zip/wheel without extracted files. Falls back to reading from __file__ + if pkgutil.get_data returns None. + + Returns: + str: The content of install_requirements.py as a string. + + Raises: + RuntimeError: If the module cannot be located in any way (e.g., in + frozen or zipped distributions where both __file__ is None and + pkgutil.get_data returns None). + """ + import pkgutil + import sagemaker.core.utils.install_requirements as _ir_mod + + _ir_mod_file = getattr(_ir_mod, "__file__", None) + _ir_mod_package = _ir_mod.__package__ or _ir_mod.__name__ + + # Determine the resource name for pkgutil.get_data + if _ir_mod_file is not None: + resource_name = os.path.basename(_ir_mod_file) + else: + resource_name = "install_requirements.py" + + # Try pkgutil.get_data first (works in zipped/wheel distributions) + try: + install_req_data = pkgutil.get_data(_ir_mod_package, resource_name) + except (OSError, FileNotFoundError): + install_req_data = None + + if install_req_data is not None: + return install_req_data.decode("utf-8") + + # Fallback: read from file directly + if _ir_mod_file is not None and os.path.isfile(_ir_mod_file): + with open(_ir_mod_file, "r") as f: + return f.read() + + raise RuntimeError( + "Cannot locate install_requirements.py: module has no __file__ " + "attribute and pkgutil.get_data returned None. This may occur " + "in frozen or zipped distributions where the resource is not accessible." + ) + def _patch_inputs_with_payload(self, inputs, s3_payload) -> List[ProcessingInput]: - """Add payload sourcedir.tar.gz to processing input.""" + """Add payload sourcedir.tar.gz to processing input. + + Note: This method creates a deep copy of the inputs list to avoid mutating + the caller's original ProcessingInput objects. + """ if inputs is None: inputs = [] - # make a shallow copy of user inputs - patched_inputs = copy(inputs) + # Deep copy to avoid mutating the caller's original ProcessingInput objects + patched_inputs = deepcopy(inputs) - # Extract the directory path from the s3_payload (remove the filename) + # Extract the directory path from the s3_payload (remove the filename). + # The trailing '/' is intentional here as it is required for the S3Prefix data type + # in ProcessingS3Input to correctly identify the prefix for downloading all objects + # under that path. s3_code_dir = s3_payload.rsplit("/", 1)[0] + "/" patched_inputs.append( @@ -1367,16 +1590,16 @@ def _generate_framework_script(self, user_script: str) -> str: return dedent( """\ #!/bin/bash - + # Exit on any error. SageMaker uses error code to mark failed job. set -e cd /opt/ml/processing/input/code/ - + # Debug: List files before extraction echo "Files in /opt/ml/processing/input/code/ before extraction:" ls -la - + # Extract source code if [ -f sourcedir.tar.gz ]; then tar -xzf sourcedir.tar.gz @@ -1388,10 +1611,7 @@ def _generate_framework_script(self, user_script: str) -> str: fi if [[ -f 'requirements.txt' ]]; then - # Some py3 containers has typing, which may breaks pip install - pip uninstall --yes typing - - python3 /opt/ml/processing/input/code/install_requirements.py requirements.txt + {entry_point_command} /opt/ml/processing/input/code/install_requirements.py requirements.txt fi {entry_point_command} {entry_point} "$@" @@ -1615,10 +1835,10 @@ def logs_for_processing_job(sagemaker_session, job_name, wait=False, poll=10): status = description["ProcessingJobStatus"] if status in ("Completed", "Failed", "Stopped"): - print() + logger.info("") state = LogState.JOB_COMPLETE if wait: _check_job_status(job_name, description, "ProcessingJobStatus") if dot: - print() + logger.info("")