diff --git a/sagemaker-core/src/sagemaker/core/common_utils.py b/sagemaker-core/src/sagemaker/core/common_utils.py index 8a8134f5ea..d3a57498ad 100644 --- a/sagemaker-core/src/sagemaker/core/common_utils.py +++ b/sagemaker-core/src/sagemaker/core/common_utils.py @@ -1747,7 +1747,11 @@ def _is_bad_path(path, base): bool: True if the path is not rooted under the base directory, False otherwise. """ # joinpath will ignore base if path is absolute - return not _get_resolved_path(joinpath(base, path)).startswith(base) + resolved = _get_resolved_path(joinpath(base, path)) + try: + return os.path.commonpath([resolved, base]) != base + except ValueError: + return True def _is_bad_link(info, base): @@ -1767,22 +1771,23 @@ def _is_bad_link(info, base): return _is_bad_path(info.linkname, base=tip) -def _get_safe_members(members): +def _get_safe_members(members, base): """A generator that yields members that are safe to extract. It filters out bad paths and bad links. Args: members (list): A list of members to check. + base (str): The base directory for extraction. Yields: tarfile.TarInfo: The tar file info. """ - base = _get_resolved_path("") - for file_info in members: if _is_bad_path(file_info.name, base): logger.error("%s is blocked (illegal path)", file_info.name) + elif file_info.isdev() or file_info.isfifo(): + logger.error("%s is blocked: special file", file_info.name) elif file_info.issym() and _is_bad_link(file_info, base): logger.error("%s is blocked: Symlink to %s", file_info.name, file_info.linkname) elif file_info.islnk() and _is_bad_link(file_info, base): @@ -1810,7 +1815,11 @@ def _validate_extracted_paths(extract_path): for dir_name in dirs: dir_path = os.path.join(root, dir_name) resolved = _get_resolved_path(dir_path) - if not resolved.startswith(base): + try: + is_within_base = os.path.commonpath([resolved, base]) == base + except ValueError: + is_within_base = False + if not is_within_base: logger.error("Extracted directory escaped extraction path: %s", dir_path) raise ValueError(f"Extracted path outside expected directory: {dir_path}") @@ -1818,7 +1827,11 @@ def _validate_extracted_paths(extract_path): for file_name in files: file_path = os.path.join(root, file_name) resolved = _get_resolved_path(file_path) - if not resolved.startswith(base): + try: + is_within_base = os.path.commonpath([resolved, base]) == base + except ValueError: + is_within_base = False + if not is_within_base: logger.error("Extracted file escaped extraction path: %s", file_path) raise ValueError(f"Extracted path outside expected directory: {file_path}") @@ -1842,7 +1855,8 @@ def custom_extractall_tarfile(tar, extract_path): if hasattr(tarfile, "data_filter"): tar.extractall(path=extract_path, filter="data") else: - tar.extractall(path=extract_path, members=_get_safe_members(tar)) + base = _get_resolved_path(extract_path) + tar.extractall(path=extract_path, members=_get_safe_members(tar.getmembers(), base)) # Re-validate extracted paths to catch symlink race conditions _validate_extracted_paths(extract_path) diff --git a/sagemaker-core/tests/unit/test_common_utils.py b/sagemaker-core/tests/unit/test_common_utils.py index 3f7fc94f67..b3275e4156 100644 --- a/sagemaker-core/tests/unit/test_common_utils.py +++ b/sagemaker-core/tests/unit/test_common_utils.py @@ -14,6 +14,7 @@ from __future__ import absolute_import import pytest +import io import time import tempfile import os @@ -1227,6 +1228,46 @@ def test_custom_extractall_tarfile_basic(self, tmp_path): assert (extract_path / "file.txt").exists() + def test_custom_extractall_tarfile_blocks_sibling_prefix_escape(self, tmp_path, monkeypatch): + """Test fallback extraction blocks sibling paths with the same directory prefix.""" + from sagemaker.core.common_utils import custom_extractall_tarfile + + monkeypatch.delattr(tarfile, "data_filter", raising=False) + extract_path = tmp_path / "src" + extract_path.mkdir() + escaped_path = tmp_path / "src_evil" / "escaped.txt" + stream = io.BytesIO() + with tarfile.open(fileobj=stream, mode="w") as tar: + data = b"outside" + file_info = tarfile.TarInfo(name="../src_evil/escaped.txt") + file_info.size = len(data) + tar.addfile(file_info, io.BytesIO(data)) + stream.seek(0) + + with tarfile.open(fileobj=stream, mode="r") as tar: + custom_extractall_tarfile(tar, str(extract_path)) + + assert not escaped_path.exists() + + def test_custom_extractall_tarfile_blocks_special_files(self, tmp_path, monkeypatch): + """Test fallback extraction blocks special files like the data filter.""" + from sagemaker.core.common_utils import custom_extractall_tarfile + + monkeypatch.delattr(tarfile, "data_filter", raising=False) + extract_path = tmp_path / "extract" + extract_path.mkdir() + stream = io.BytesIO() + with tarfile.open(fileobj=stream, mode="w") as tar: + fifo_info = tarfile.TarInfo(name="blocked_fifo") + fifo_info.type = tarfile.FIFOTYPE + tar.addfile(fifo_info) + stream.seek(0) + + with tarfile.open(fileobj=stream, mode="r") as tar: + custom_extractall_tarfile(tar, str(extract_path)) + + assert not (extract_path / "blocked_fifo").exists() + class TestCanModelPackageSourceUriAutopopulate: """Test can_model_package_source_uri_autopopulate function."""