diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/_artifacts/_blob_storage_helper.py b/sdk/ml/azure-ai-ml/azure/ai/ml/_artifacts/_blob_storage_helper.py
index 4cd8ced1e9a7..5db5008ccefe 100644
--- a/sdk/ml/azure-ai-ml/azure/ai/ml/_artifacts/_blob_storage_helper.py
+++ b/sdk/ml/azure-ai-ml/azure/ai/ml/_artifacts/_blob_storage_helper.py
@@ -246,9 +246,15 @@ def download(
         try:
             my_list = list(self.container_client.list_blobs(name_starts_with=starts_with, include="metadata"))
             download_size_in_mb = 0
+            resolved_destination = Path(destination).resolve()
             for item in my_list:
                 blob_name = item.name[len(starts_with) :].lstrip("/") or Path(starts_with).name
                 target_path = Path(destination, blob_name).resolve()
+                # Blob names come from the service; compare path components so a root destination still works.
+                if target_path != resolved_destination and resolved_destination not in target_path.parents:
+                    raise ValueError(
+                        f"Blob name contains a path traversal entry and cannot be downloaded safely: {item.name}"
+                    )
 
                 if _blob_is_hdi_folder(item):
                     target_path.mkdir(parents=True, exist_ok=True)
diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/_artifacts/_fileshare_storage_helper.py b/sdk/ml/azure-ai-ml/azure/ai/ml/_artifacts/_fileshare_storage_helper.py
index 01da568cb59a..6c4bd120d096 100644
--- a/sdk/ml/azure-ai-ml/azure/ai/ml/_artifacts/_fileshare_storage_helper.py
+++ b/sdk/ml/azure-ai-ml/azure/ai/ml/_artifacts/_fileshare_storage_helper.py
@@ -403,16 +403,28 @@ def recursive_download(
         files = [item for item in items if not item["is_directory"]]
         folders = [item for item in items if item["is_directory"]]
 
+        resolved_destination = Path(destination).resolve()
         for f in files:
             Path(destination).mkdir(parents=True, exist_ok=True)
             file_name = f["name"]
+            # File names come from the service; refuse anything that resolves outside the destination.
+            local_path = Path(destination, file_name).resolve()
+            if local_path != resolved_destination and resolved_destination not in local_path.parents:
+                raise ValueError(
+                    f"File name contains a path traversal entry and cannot be downloaded safely: {file_name}"
+                )
             file_client = client.get_file_client(file_name)
             file_content = file_client.download_file(max_concurrency=max_concurrency)
-            local_path = Path(destination, file_name)
             with open(local_path, "wb") as file_data:
                 file_data.write(file_content.readall())
 
         for f in folders:
+            # Directory names come from the service; validate before creating a client or recursing.
+            sub_destination = Path(destination, f["name"]).resolve()
+            if sub_destination != resolved_destination and resolved_destination not in sub_destination.parents:
+                raise ValueError(
+                    f"Directory name contains a path traversal entry and cannot be downloaded safely: {f['name']}"
+                )
             sub_client = client.get_subdirectory_client(f["name"])
-            destination = "/".join((destination, f["name"]))
-            recursive_download(sub_client, destination=destination, max_concurrency=max_concurrency)
+            # Recurse into the validated path; rebinding `destination` here would nest sibling folders.
+            recursive_download(sub_client, destination=str(sub_destination), max_concurrency=max_concurrency)
diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/_artifacts/_gen2_storage_helper.py b/sdk/ml/azure-ai-ml/azure/ai/ml/_artifacts/_gen2_storage_helper.py
index 7cd3b6c6dab0..0db948a14a0f 100644
--- a/sdk/ml/azure-ai-ml/azure/ai/ml/_artifacts/_gen2_storage_helper.py
+++ b/sdk/ml/azure-ai-ml/azure/ai/ml/_artifacts/_gen2_storage_helper.py
@@ -206,9 +206,15 @@ def download(self, starts_with: str, destination: Union[str, os.PathLike] = Path
         try:
             mylist = self.file_system_client.get_paths(path=starts_with)
             download_size_in_mb = 0
+            resolved_destination = Path(destination).resolve()
             for item in mylist:
                 file_name = item.name[len(starts_with) :].lstrip("/") or Path(starts_with).name
-                target_path = Path(destination, file_name)
+                target_path = Path(destination, file_name).resolve()
+                # Path names come from the service; compare path components so a root destination still works.
+                if target_path != resolved_destination and resolved_destination not in target_path.parents:
+                    raise ValueError(
+                        f"Path name contains a path traversal entry and cannot be downloaded safely: {item.name}"
+                    )
 
                 if item.is_directory:
                     target_path.mkdir(parents=True, exist_ok=True)
diff --git a/sdk/ml/azure-ai-ml/tests/internal_utils/unittests/test_storage_download_traversal.py b/sdk/ml/azure-ai-ml/tests/internal_utils/unittests/test_storage_download_traversal.py
new file mode 100644
index 000000000000..c3be802ecdb2
--- /dev/null
+++ b/sdk/ml/azure-ai-ml/tests/internal_utils/unittests/test_storage_download_traversal.py
@@ -0,0 +1,109 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from azure.ai.ml.exceptions import MlException
+
+_CLOUD = {"storage_endpoint": "core.windows.net"}
+
+
+def _named(name, is_directory=False):
+    item = MagicMock()
+    item.name = name
+    item.is_directory = is_directory
+    return item
+
+
+@pytest.mark.unittest
+class TestStorageDownloadTraversal:
+    """Server-controlled blob/file names with ``..`` segments must not escape the destination."""
+
+    def test_blob_download_rejects_path_traversal(self, tmp_path):
+        from azure.ai.ml._artifacts._blob_storage_helper import BlobStorageClient
+
+        with patch("azure.ai.ml._artifacts._blob_storage_helper.BlobServiceClient"):
+            client = BlobStorageClient(
+                credential="cred", account_url="https://acct.blob.core.windows.net", container_name="c"
+            )
+        client.container_client = MagicMock()
+        client.container_client.list_blobs.return_value = [_named("asset/../escaped.txt")]
+        blob_content = client.container_client.download_blob.return_value
+        blob_content.size = 1
+        blob_content.content_as_bytes.return_value = b"data"
+
+        dest = tmp_path / "dest"
+        dest.mkdir()
+        with patch("azure.ai.ml._artifacts._blob_storage_helper._blob_is_hdi_folder", return_value=False), patch(
+            "azure.ai.ml._artifacts._blob_storage_helper._get_cloud_details", return_value=_CLOUD
+        ):
+            with pytest.raises(MlException):
+                client.download(starts_with="asset/", destination=str(dest))
+        assert not (tmp_path / "escaped.txt").exists()
+
+    def test_gen2_download_rejects_path_traversal(self, tmp_path):
+        from azure.ai.ml._artifacts._gen2_storage_helper import Gen2StorageClient
+
+        with patch("azure.ai.ml._artifacts._gen2_storage_helper.DataLakeServiceClient"):
+            client = Gen2StorageClient(
+                credential="cred", file_system="fs", account_url="https://acct.dfs.core.windows.net"
+            )
+        client.file_system_client = MagicMock()
+        client.file_system_client.get_paths.return_value = [_named("asset/../escaped.txt")]
+        file_client = client.file_system_client.get_file_client.return_value
+        file_client.get_file_properties.return_value.size = 1
+        file_client.download_file.return_value.readall.return_value = b"data"
+
+        dest = tmp_path / "dest"
+        dest.mkdir()
+        with patch("azure.ai.ml._artifacts._gen2_storage_helper._get_cloud_details", return_value=_CLOUD):
+            with pytest.raises(MlException):
+                client.download(starts_with="asset/", destination=str(dest))
+        assert not (tmp_path / "escaped.txt").exists()
+
+    def test_fileshare_recursive_download_rejects_path_traversal(self, tmp_path):
+        from azure.ai.ml._artifacts._fileshare_storage_helper import recursive_download
+
+        client = MagicMock()
+        client.list_directories_and_files.return_value = [{"name": "../escaped.txt", "is_directory": False}]
+        client.get_file_client.return_value.download_file.return_value.readall.return_value = b"data"
+
+        dest = tmp_path / "dest"
+        dest.mkdir()
+        with pytest.raises(MlException):
+            recursive_download(client, destination=str(dest), max_concurrency=1)
+        assert not (tmp_path / "escaped.txt").exists()
+
+    def test_fileshare_recursive_download_rejects_directory_traversal(self, tmp_path):
+        from azure.ai.ml._artifacts._fileshare_storage_helper import recursive_download
+
+        client = MagicMock()
+        client.list_directories_and_files.return_value = [{"name": "../escaped", "is_directory": True}]
+
+        dest = tmp_path / "dest"
+        dest.mkdir()
+        with pytest.raises(MlException):
+            recursive_download(client, destination=str(dest), max_concurrency=1)
+        client.get_subdirectory_client.assert_not_called()
+
+    def test_fileshare_recursive_download_keeps_sibling_directories_separate(self, tmp_path):
+        from azure.ai.ml._artifacts._fileshare_storage_helper import recursive_download
+
+        def _sub_client(name):
+            sub = MagicMock()
+            sub.list_directories_and_files.return_value = [{"name": f"{name}.txt", "is_directory": False}]
+            sub.get_file_client.return_value.download_file.return_value.readall.return_value = b"data"
+            return sub
+
+        client = MagicMock()
+        client.list_directories_and_files.return_value = [
+            {"name": "a", "is_directory": True},
+            {"name": "b", "is_directory": True},
+        ]
+        client.get_subdirectory_client.side_effect = _sub_client
+
+        dest = tmp_path / "dest"
+        dest.mkdir()
+        recursive_download(client, destination=str(dest), max_concurrency=1)
+        assert (dest / "a" / "a.txt").exists()
+        assert (dest / "b" / "b.txt").exists()
+        assert not (dest / "a" / "b").exists()