Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -246,9 +246,16 @@ def download(
try:
my_list = list(self.container_client.list_blobs(name_starts_with=starts_with, include="metadata"))
download_size_in_mb = 0
resolved_destination = Path(destination).resolve()
for item in my_list:
blob_name = item.name[len(starts_with) :].lstrip("/") or Path(starts_with).name
target_path = Path(destination, blob_name).resolve()
if target_path != resolved_destination and not str(target_path).startswith(
str(resolved_destination) + os.sep
):
Comment on lines +249 to +255
raise ValueError(
f"Blob name contains a path traversal entry and cannot be downloaded safely: {item.name}"
)

if _blob_is_hdi_folder(item):
target_path.mkdir(parents=True, exist_ok=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -403,16 +403,30 @@ def recursive_download(
files = [item for item in items if not item["is_directory"]]
folders = [item for item in items if item["is_directory"]]

resolved_destination = Path(destination).resolve()
for f in files:
Path(destination).mkdir(parents=True, exist_ok=True)
file_name = f["name"]
local_path = Path(destination, file_name).resolve()
if local_path != resolved_destination and not str(local_path).startswith(
str(resolved_destination) + os.sep
):
raise ValueError(
f"File name contains a path traversal entry and cannot be downloaded safely: {file_name}"
)
file_client = client.get_file_client(file_name)
file_content = file_client.download_file(max_concurrency=max_concurrency)
local_path = Path(destination, file_name)
with open(local_path, "wb") as file_data:
file_data.write(file_content.readall())

for f in folders:
sub_destination = Path(destination, f["name"]).resolve()
if sub_destination != resolved_destination and not str(sub_destination).startswith(
str(resolved_destination) + os.sep
):
raise ValueError(
f"Directory name contains a path traversal entry and cannot be downloaded safely: {f['name']}"
)
sub_client = client.get_subdirectory_client(f["name"])
destination = "/".join((destination, f["name"]))
recursive_download(sub_client, destination=destination, max_concurrency=max_concurrency)
Comment on lines 430 to 432
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,16 @@ def download(self, starts_with: str, destination: Union[str, os.PathLike] = Path
try:
mylist = self.file_system_client.get_paths(path=starts_with)
download_size_in_mb = 0
resolved_destination = Path(destination).resolve()
for item in mylist:
file_name = item.name[len(starts_with) :].lstrip("/") or Path(starts_with).name
target_path = Path(destination, file_name)
target_path = Path(destination, file_name).resolve()
if target_path != resolved_destination and not str(target_path).startswith(
str(resolved_destination) + os.sep
):
raise ValueError(
f"Path name contains a path traversal entry and cannot be downloaded safely: {item.name}"
)
Comment on lines +209 to +218

if item.is_directory:
target_path.mkdir(parents=True, exist_ok=True)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from unittest.mock import MagicMock, patch

import pytest

from azure.ai.ml.exceptions import MlException

# Canned value returned by the patched ``_get_cloud_details`` in the tests below.
_CLOUD = {"storage_endpoint": "core.windows.net"}


def _named(name, is_directory=False):
item = MagicMock()
item.name = name
item.is_directory = is_directory
return item


@pytest.mark.unittest
class TestStorageDownloadTraversal:
    """Server-controlled blob/file names with ``..`` segments must not escape the destination."""

    def test_blob_download_rejects_path_traversal(self, tmp_path):
        """Blob download must raise and write nothing outside ``dest`` for a ``..`` blob name."""
        from azure.ai.ml._artifacts._blob_storage_helper import BlobStorageClient

        # Patch the service client so constructing the helper performs no network I/O.
        with patch("azure.ai.ml._artifacts._blob_storage_helper.BlobServiceClient"):
            client = BlobStorageClient(
                credential="cred", account_url="https://acct.blob.core.windows.net", container_name="c"
            )
        client.container_client = MagicMock()
        # The listing returns a single blob whose name contains a parent-directory segment.
        client.container_client.list_blobs.return_value = [_named("asset/../escaped.txt")]
        blob_content = client.container_client.download_blob.return_value
        blob_content.size = 1
        blob_content.content_as_bytes.return_value = b"data"

        dest = tmp_path / "dest"
        dest.mkdir()
        with patch("azure.ai.ml._artifacts._blob_storage_helper._blob_is_hdi_folder", return_value=False), patch(
            "azure.ai.ml._artifacts._blob_storage_helper._get_cloud_details", return_value=_CLOUD
        ):
            with pytest.raises(MlException):
                client.download(starts_with="asset/", destination=str(dest))
        # Checked after the ``raises`` context so the assertion is actually executed.
        assert not (tmp_path / "escaped.txt").exists()

    def test_gen2_download_rejects_path_traversal(self, tmp_path):
        """Gen2 download must raise and write nothing outside ``dest`` for a ``..`` path name."""
        from azure.ai.ml._artifacts._gen2_storage_helper import Gen2StorageClient

        # Patch the service client so constructing the helper performs no network I/O.
        with patch("azure.ai.ml._artifacts._gen2_storage_helper.DataLakeServiceClient"):
            client = Gen2StorageClient(
                credential="cred", file_system="fs", account_url="https://acct.dfs.core.windows.net"
            )
        client.file_system_client = MagicMock()
        # The listing returns a single path whose name contains a parent-directory segment.
        client.file_system_client.get_paths.return_value = [_named("asset/../escaped.txt")]
        file_client = client.file_system_client.get_file_client.return_value
        file_client.get_file_properties.return_value.size = 1
        file_client.download_file.return_value.readall.return_value = b"data"

        dest = tmp_path / "dest"
        dest.mkdir()
        with patch("azure.ai.ml._artifacts._gen2_storage_helper._get_cloud_details", return_value=_CLOUD):
            with pytest.raises(MlException):
                client.download(starts_with="asset/", destination=str(dest))
        # Checked after the ``raises`` context so the assertion is actually executed.
        assert not (tmp_path / "escaped.txt").exists()

    def test_fileshare_recursive_download_rejects_path_traversal(self, tmp_path):
        """File-share recursive download must raise and write nothing outside ``dest``."""
        from azure.ai.ml._artifacts._fileshare_storage_helper import recursive_download

        client = MagicMock()
        # A single non-directory entry whose name points at the parent of ``dest``.
        client.list_directories_and_files.return_value = [{"name": "../escaped.txt", "is_directory": False}]
        client.get_file_client.return_value.download_file.return_value.readall.return_value = b"data"

        dest = tmp_path / "dest"
        dest.mkdir()
        with pytest.raises(MlException):
            recursive_download(client, destination=str(dest), max_concurrency=1)
        # Checked after the ``raises`` context so the assertion is actually executed.
        assert not (tmp_path / "escaped.txt").exists()