Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions sagemaker-core/src/sagemaker/core/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1747,7 +1747,11 @@ def _is_bad_path(path, base):
bool: True if the path is not rooted under the base directory, False otherwise.
"""
# joinpath will ignore base if path is absolute
return not _get_resolved_path(joinpath(base, path)).startswith(base)
resolved = _get_resolved_path(joinpath(base, path))
try:
return os.path.commonpath([resolved, base]) != base
except ValueError:
return True


def _is_bad_link(info, base):
Expand All @@ -1767,22 +1771,23 @@ def _is_bad_link(info, base):
return _is_bad_path(info.linkname, base=tip)


def _get_safe_members(members):
def _get_safe_members(members, base):
"""A generator that yields members that are safe to extract.

It filters out bad paths and bad links.

Args:
members (list): A list of members to check.
base (str): The base directory for extraction.

Yields:
tarfile.TarInfo: The tar file info.
"""
base = _get_resolved_path("")

for file_info in members:
if _is_bad_path(file_info.name, base):
logger.error("%s is blocked (illegal path)", file_info.name)
elif file_info.isdev() or file_info.isfifo():
logger.error("%s is blocked: special file", file_info.name)
elif file_info.issym() and _is_bad_link(file_info, base):
logger.error("%s is blocked: Symlink to %s", file_info.name, file_info.linkname)
elif file_info.islnk() and _is_bad_link(file_info, base):
Expand Down Expand Up @@ -1810,15 +1815,23 @@ def _validate_extracted_paths(extract_path):
for dir_name in dirs:
dir_path = os.path.join(root, dir_name)
resolved = _get_resolved_path(dir_path)
if not resolved.startswith(base):
try:
is_within_base = os.path.commonpath([resolved, base]) == base
except ValueError:
is_within_base = False
if not is_within_base:
logger.error("Extracted directory escaped extraction path: %s", dir_path)
raise ValueError(f"Extracted path outside expected directory: {dir_path}")

# Check files
for file_name in files:
file_path = os.path.join(root, file_name)
resolved = _get_resolved_path(file_path)
if not resolved.startswith(base):
try:
is_within_base = os.path.commonpath([resolved, base]) == base
except ValueError:
is_within_base = False
if not is_within_base:
logger.error("Extracted file escaped extraction path: %s", file_path)
raise ValueError(f"Extracted path outside expected directory: {file_path}")

Expand All @@ -1842,7 +1855,8 @@ def custom_extractall_tarfile(tar, extract_path):
if hasattr(tarfile, "data_filter"):
tar.extractall(path=extract_path, filter="data")
else:
tar.extractall(path=extract_path, members=_get_safe_members(tar))
base = _get_resolved_path(extract_path)
tar.extractall(path=extract_path, members=_get_safe_members(tar.getmembers(), base))
# Re-validate extracted paths to catch symlink race conditions
_validate_extracted_paths(extract_path)

Expand Down
41 changes: 41 additions & 0 deletions sagemaker-core/tests/unit/test_common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from __future__ import absolute_import

import pytest
import io
import time
import tempfile
import os
Expand Down Expand Up @@ -1227,6 +1228,46 @@ def test_custom_extractall_tarfile_basic(self, tmp_path):

assert (extract_path / "file.txt").exists()

def test_custom_extractall_tarfile_blocks_sibling_prefix_escape(self, tmp_path, monkeypatch):
    """Fallback extraction must not write into a sibling directory sharing the base's prefix.

    The archive contains one member, ``../src_evil/escaped.txt``, and is
    extracted into ``<tmp_path>/src``. The sibling ``src_evil`` begins with
    the same characters as ``src``, so a plain string-prefix comparison of
    the two paths would not tell them apart.
    """
    from sagemaker.core.common_utils import custom_extractall_tarfile

    # With ``tarfile.data_filter`` absent, the hasattr() check in
    # custom_extractall_tarfile fails and its fallback branch is taken.
    monkeypatch.delattr(tarfile, "data_filter", raising=False)

    target_dir = tmp_path / "src"
    target_dir.mkdir()
    outside_file = tmp_path / "src_evil" / "escaped.txt"

    # Describe the single malicious member up front.
    payload = b"outside"
    member = tarfile.TarInfo(name="../src_evil/escaped.txt")
    member.size = len(payload)

    # Build the archive entirely in memory.
    archive = io.BytesIO()
    with tarfile.open(fileobj=archive, mode="w") as writer:
        writer.addfile(member, io.BytesIO(payload))
    archive.seek(0)

    with tarfile.open(fileobj=archive, mode="r") as reader:
        custom_extractall_tarfile(reader, str(target_dir))

    # Nothing may have been written outside the extraction directory.
    assert not outside_file.exists()

def test_custom_extractall_tarfile_blocks_special_files(self, tmp_path, monkeypatch):
    """Fallback extraction must not create a FIFO member found in the archive.

    The archive contains a single ``FIFOTYPE`` entry named ``blocked_fifo``;
    after extraction no such entry may exist in the target directory.
    """
    from sagemaker.core.common_utils import custom_extractall_tarfile

    # With ``tarfile.data_filter`` absent, the hasattr() check in
    # custom_extractall_tarfile fails and its fallback branch is taken.
    monkeypatch.delattr(tarfile, "data_filter", raising=False)

    target_dir = tmp_path / "extract"
    target_dir.mkdir()

    # A header-only member: FIFOs carry no data, so no file object is attached.
    fifo_member = tarfile.TarInfo(name="blocked_fifo")
    fifo_member.type = tarfile.FIFOTYPE

    # Build the archive entirely in memory.
    archive = io.BytesIO()
    with tarfile.open(fileobj=archive, mode="w") as writer:
        writer.addfile(fifo_member)
    archive.seek(0)

    with tarfile.open(fileobj=archive, mode="r") as reader:
        custom_extractall_tarfile(reader, str(target_dir))

    assert not (target_dir / "blocked_fifo").exists()


class TestCanModelPackageSourceUriAutopopulate:
"""Test can_model_package_source_uri_autopopulate function."""
Expand Down
Loading