diff --git a/clients/python/README.md b/clients/python/README.md index 64b4c818..7f80e617 100644 --- a/clients/python/README.md +++ b/clients/python/README.md @@ -91,6 +91,66 @@ Arbitrary key-value pairs can be attached to objects and retrieved on download. session.put(b"payload", metadata={"source": "upload-service"}) ``` +### Multipart Upload API + +For large objects, use multipart uploads to upload parts independently and then +assemble them into a final object. + +**Important:** unlike single-object uploads, multipart uploads do **not** auto-compress. +The caller must pre-compress each part with the compression algorithm that was declared +when initiating the upload (it is only recorded in the object's metadata). + +```python +from concurrent.futures import ThreadPoolExecutor + +import zstandard + +from objectstore_client.multipart import MultipartCompleteError + +upload = session.initiate_multipart_upload( + key="my-large-object", + compression="zstd", + metadata={"source": "upload-service"}, +) + +chunks = [b"part1", b"part2", b"part3", b"part4"] + +def upload_part(part_number: int, data: bytes): + # ZstdCompressor instances are not thread-safe, so create one per call. + compressed = zstandard.ZstdCompressor().compress(data) + return upload.put_part( + compressed, part_number=part_number, content_length=len(compressed) + ) + +with ThreadPoolExecutor(max_workers=4) as executor: + futures = [ + executor.submit(upload_part, i + 1, chunk) + for i, chunk in enumerate(chunks) + ] + parts = [f.result() for f in futures] + +try: + key = upload.complete(parts) +except MultipartCompleteError: + upload.abort() + raise +``` + +To resume an in-progress multipart upload after a process restart, persist the +`key` and `upload_id`, then reconstruct the upload handle later: + +```python +saved_key = upload.key +saved_upload_id = upload.upload_id + +resumed = session.resume_multipart_upload(saved_key, saved_upload_id) +existing_parts = resumed.list_parts() + +# Upload missing parts... 
+ +key = resumed.complete(new_parts + existing_parts) +``` + ### Authentication If your Objectstore instance enforces authorization, you must configure authentication diff --git a/clients/python/docs/conf.py b/clients/python/docs/conf.py index b0c4d7d3..8b1c9450 100644 --- a/clients/python/docs/conf.py +++ b/clients/python/docs/conf.py @@ -25,3 +25,8 @@ "undoc-members": True, "show-inheritance": True, } + +# Re-exported symbols in __init__.py create duplicate Sphinx targets +# (e.g. objectstore_client.Session vs objectstore_client.client.Session). +# This is the most specific suppression Sphinx supports for that warning. +suppress_warnings = ["ref.python"] diff --git a/clients/python/docs/objectstore_client.rst b/clients/python/docs/objectstore_client.rst index d45997fd..28fdf9a9 100644 --- a/clients/python/docs/objectstore_client.rst +++ b/clients/python/docs/objectstore_client.rst @@ -25,6 +25,14 @@ objectstore\_client.client module :show-inheritance: :undoc-members: +objectstore\_client.errors module +--------------------------------- + +.. automodule:: objectstore_client.errors + :members: + :show-inheritance: + :undoc-members: + objectstore\_client.metadata module ----------------------------------- @@ -41,6 +49,14 @@ objectstore\_client.metrics module :show-inheritance: :undoc-members: +objectstore\_client.multipart module +------------------------------------ + +.. 
automodule:: objectstore_client.multipart + :members: + :show-inheritance: + :undoc-members: + objectstore\_client.scope module -------------------------------- diff --git a/clients/python/src/objectstore_client/__init__.py b/clients/python/src/objectstore_client/__init__.py index 41435258..fa26f417 100644 --- a/clients/python/src/objectstore_client/__init__.py +++ b/clients/python/src/objectstore_client/__init__.py @@ -2,10 +2,10 @@ from objectstore_client.client import ( Client, GetResponse, - RequestError, Session, Usecase, ) +from objectstore_client.errors import RequestError from objectstore_client.metadata import ( Compression, ExpirationPolicy, diff --git a/clients/python/src/objectstore_client/client.py b/clients/python/src/objectstore_client/client.py index 74f18243..a24b0397 100644 --- a/clients/python/src/objectstore_client/client.py +++ b/clients/python/src/objectstore_client/client.py @@ -13,6 +13,7 @@ from objectstore_client import utils from objectstore_client.auth import Permission, TokenGenerator, TokenProvider +from objectstore_client.errors import raise_for_status from objectstore_client.metadata import ( HEADER_EXPIRATION, HEADER_META_PREFIX, @@ -27,6 +28,7 @@ NoOpMetricsBackend, measure_storage_operation, ) +from objectstore_client.multipart import MultipartUpload from objectstore_client.scope import Scope @@ -35,15 +37,6 @@ class GetResponse(NamedTuple): payload: IO[bytes] -class RequestError(Exception): - """Exception raised if an API call to Objectstore fails.""" - - def __init__(self, message: str, status: int, response: str): - super().__init__(message) - self.status = status - self.response = response - - class Usecase: """ An identifier for a workload in Objectstore, along with defaults to use for all @@ -281,6 +274,25 @@ def _make_url(self, key: str | None, full: bool = False) -> str: return f"http://{self._pool.host}:{self._pool.port}{path}" return path + def _make_multipart_url( + self, + action: str | None, + key: str | None, + 
query: str | None = None, + ) -> str: + if action == "parts": + resource = "objects:multipart:parts" + elif action == "complete": + resource = "objects:multipart:complete" + else: + resource = "objects:multipart" + + relative_path = f"/v1/{resource}/{self._usecase.name}/{self._scope}/{key or ''}" + path = self._base_path.rstrip("/") + relative_path + if query: + return f"{path}?{query}" + return path + def put( self, contents: bytes | IO[bytes], @@ -445,12 +457,74 @@ def delete(self, key: str) -> None: ) raise_for_status(response) + def initiate_multipart_upload( + self, + *, + key: str | None = None, + compression: Compression | Literal["none"] | None = None, + content_type: str | None = None, + metadata: dict[str, str] | None = None, + expiration_policy: ExpirationPolicy | None = None, + origin: str | None = None, + ) -> MultipartUpload: + """ + Initiates a multipart upload. -def raise_for_status(response: urllib3.BaseHTTPResponse) -> None: - if response.status >= 400: - res = (response.data or response.read() or b"").decode("utf-8", "replace") - raise RequestError( - f"Objectstore request failed with status {response.status}", - response.status, - res, - ) + Returns a :class:`~objectstore_client.multipart.MultipartUpload` handle + that can be used to upload parts, list parts, complete, or abort. + + **Important:** unlike :meth:`put`, the ``compression`` parameter only + records the compression algorithm in the object's metadata. + The caller is responsible for compressing each part in accordance with the + chosen algorithm before passing it to + :meth:`~objectstore_client.multipart.MultipartUpload.put_part`. 
+ """ + if compression and compression not in ("none", "zstd"): + raise ValueError(f"Invalid compression: {compression}") + + headers = self._make_headers() + + compression = compression or self._usecase._compression + if compression and compression != "none": + headers["Content-Encoding"] = compression + + if content_type: + headers["Content-Type"] = content_type + + expiration_policy = expiration_policy or self._usecase._expiration_policy + if expiration_policy: + headers[HEADER_EXPIRATION] = format_expiration(expiration_policy) + + if origin: + headers[HEADER_ORIGIN] = origin + + if metadata: + for k, v in metadata.items(): + headers[f"{HEADER_META_PREFIX}{k}"] = v + + if key == "": + key = None + + with measure_storage_operation( + self._metrics_backend, "multipart.initiate", self._usecase.name + ): + response = self._pool.request( + "POST" if not key else "PUT", + self._make_multipart_url(None, key), + headers=headers, + preload_content=True, + decode_content=True, + ) + raise_for_status(response) + res = response.json() + return MultipartUpload(self, res["key"], res["upload_id"]) + + def resume_multipart_upload(self, key: str, upload_id: str) -> MultipartUpload: + """ + Reconstructs a multipart upload handle. + + This does not make any network calls. + Use it to resume an upload after a process restart or to + continue an upload started elsewhere. 
+ """ + return MultipartUpload(self, key, upload_id) diff --git a/clients/python/src/objectstore_client/errors.py b/clients/python/src/objectstore_client/errors.py new file mode 100644 index 00000000..099ffa17 --- /dev/null +++ b/clients/python/src/objectstore_client/errors.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +import urllib3 + + +class RequestError(Exception): + """Exception raised if an API call to Objectstore fails.""" + + def __init__(self, message: str, status: int, response: str): + super().__init__(message) + self.status = status + self.response = response + + +def raise_for_status(response: urllib3.BaseHTTPResponse) -> None: + if response.status >= 400: + res = (response.data or response.read() or b"").decode("utf-8", "replace") + raise RequestError( + f"Objectstore request failed with status {response.status}", + response.status, + res, + ) diff --git a/clients/python/src/objectstore_client/metrics.py b/clients/python/src/objectstore_client/metrics.py index 801875ca..6c7e59d0 100644 --- a/clients/python/src/objectstore_client/metrics.py +++ b/clients/python/src/objectstore_client/metrics.py @@ -83,6 +83,7 @@ def __init__(self, backend: MetricsBackend, operation: str, usecase: str): # These may be set during or after the enclosed operation self.start: int | None = None self.elapsed: float | None = None + self.size: int | None = None self.uncompressed_size: int | None = None self.compressed_size: int | None = None self.compression: str = "unknown" @@ -101,6 +102,13 @@ def record_uncompressed_size(self, value: int) -> None: ) self.uncompressed_size = value + def record_size(self, value: int) -> None: + tags = {"usecase": self.usecase} + self.backend.distribution( + f"storage.{self.operation}.size", value, tags=tags, unit="byte" + ) + self.size = value + def record_compressed_size(self, value: int, compression: str = "unknown") -> None: tags = {"usecase": self.usecase, "compression": compression} self.backend.distribution( @@ -124,14 
+132,18 @@ def maybe_record_throughputs(self) -> None: if not self.elapsed or self.elapsed <= 0: return None - sizes = [] + sizes: list[tuple[int, str | None]] = [] + if self.size: + sizes.append((self.size, None)) if self.uncompressed_size: sizes.append((self.uncompressed_size, "none")) if self.compressed_size: sizes.append((self.compressed_size, self.compression)) for size, compression in sizes: - tags = {"usecase": self.usecase, "compression": compression} + tags: dict[str, str] = {"usecase": self.usecase} + if compression is not None: + tags["compression"] = compression self.backend.distribution( f"storage.{self.operation}.throughput", size / self.elapsed, tags=tags ) diff --git a/clients/python/src/objectstore_client/multipart.py b/clients/python/src/objectstore_client/multipart.py new file mode 100644 index 00000000..62ee8134 --- /dev/null +++ b/clients/python/src/objectstore_client/multipart.py @@ -0,0 +1,246 @@ +from __future__ import annotations + +import base64 +import json +from collections.abc import Sequence +from dataclasses import dataclass +from datetime import datetime +from io import BytesIO +from typing import IO, TYPE_CHECKING +from urllib.parse import urlencode + +from objectstore_client.errors import RequestError, raise_for_status +from objectstore_client.metrics import measure_storage_operation + +if TYPE_CHECKING: + from objectstore_client.client import Session + + +@dataclass +class CompletePart: + """A reference to an uploaded part, used when completing a multipart upload.""" + + part_number: int + etag: str + + +@dataclass +class PartInfo: + """Information about an uploaded part""" + + part_number: int + etag: str + last_modified: datetime + size: int + + +class MultipartCompleteError(RequestError): + """Error returned as part of a multipart:complete response's body.""" + + def __init__(self, code: str, message: str): + super().__init__( + f"Multipart upload completion failed ({code}): {message}", + status=200, + response=message, + ) + 
self.code = code + + +class MultipartUpload: + """ + Handle for an in-progress multipart upload. + + Create via :meth:`~objectstore_client.client.Session.initiate_multipart_upload` or + :meth:`~objectstore_client.client.Session.resume_multipart_upload`. + """ + + def __init__(self, session: Session, key: str, upload_id: str): + self._session = session + self._key = key + self._upload_id = upload_id + + @property + def key(self) -> str: + return self._key + + @property + def upload_id(self) -> str: + return self._upload_id + + def put_part( + self, + contents: bytes | IO[bytes], + *, + part_number: int, + content_length: int, + content_md5: bytes | None = None, + ) -> CompletePart: + """ + Uploads a single part. + + IMPORTANT: Unlike + :meth:`~objectstore_client.client.Session.put`, + this does **not** automatically compress `contents`. + The caller must pre-compress each part according to the + compression set as part of the metadata when initiating + the upload. + + Args: + contents: The part data. If this upload was initiated + with compression, this must be pre-compressed. + part_number: 1-indexed part number. + content_length: The length in bytes of the payload + being uploaded. If this upload was initiated with + compression, this must be the post-compression + length. + content_md5: Optional raw MD5 digest of `contents`. 
+ """ + if isinstance(contents, bytes): + if len(contents) != content_length: + raise ValueError( + "content_length must match the size of the provided payload" + ) + body: bytes | IO[bytes] = BytesIO(contents) + else: + body = contents + + if content_md5 is not None and len(content_md5) != 16: + raise ValueError("content_md5 must be exactly 16 bytes") + + headers = self._session._make_headers() + headers["Content-Length"] = str(content_length) + + if content_md5 is not None: + headers["Content-MD5"] = base64.b64encode(content_md5).decode("ascii") + + query = urlencode( + {"upload_id": self._upload_id, "part_number": str(part_number)} + ) + url = self._session._make_multipart_url("parts", self._key, query) + + with measure_storage_operation( + self._session._metrics_backend, + "multipart.put_part", + self._session._usecase.name, + ) as metric_emitter: + response = self._session._pool.request( + "PUT", + url, + body=body, + headers=headers, + preload_content=True, + decode_content=True, + ) + raise_for_status(response) + res = response.json() + metric_emitter.record_size(content_length) + return CompletePart(part_number=part_number, etag=res["etag"]) + + def list_parts(self) -> list[PartInfo]: + """Lists all uploaded parts.""" + all_parts: list[PartInfo] = [] + marker: int | None = None + + while True: + params: dict[str, str] = {"upload_id": self._upload_id} + if marker is not None: + params["part_number_marker"] = str(marker) + + query = urlencode(params) + url = self._session._make_multipart_url("parts", self._key, query) + headers = self._session._make_headers() + + response = self._session._pool.request( + "GET", + url, + headers=headers, + preload_content=True, + ) + raise_for_status(response) + data = response.json() + + for p in data["parts"]: + all_parts.append( + PartInfo( + part_number=p["part_number"], + etag=p["etag"], + last_modified=datetime.fromisoformat(p["last_modified"]), + size=p["size"], + ) + ) + + if not data["is_truncated"]: + return 
all_parts + + marker = data.get("next_part_number_marker") + if marker is None: + raise RequestError( + "Server returned is_truncated=true but no next_part_number_marker", + status=200, + response=str(data), + ) + + def abort(self) -> None: + """Aborts this multipart upload, cleaning up server-side state.""" + query = urlencode({"upload_id": self._upload_id}) + url = self._session._make_multipart_url(None, self._key, query) + headers = self._session._make_headers() + + response = self._session._pool.request( + "DELETE", + url, + headers=headers, + ) + raise_for_status(response) + + def complete(self, parts: Sequence[CompletePart | PartInfo]) -> str: + """Completes the multipart upload, assembling all parts into the final object. + + Returns the final object key. + + Raises :class:`MultipartCompleteError` if the server reports an error + during assembly, or :class:`RequestError` if the server returns a non-2XX + response. + """ + query = urlencode({"upload_id": self._upload_id}) + url = self._session._make_multipart_url("complete", self._key, query) + headers = self._session._make_headers() + headers["Content-Type"] = "application/json" + + sorted_parts = sorted(parts, key=lambda p: p.part_number) + request_body = json.dumps( + { + "parts": [ + {"part_number": p.part_number, "etag": p.etag} for p in sorted_parts + ] + } + ).encode("utf-8") + + with measure_storage_operation( + self._session._metrics_backend, + "multipart.complete", + self._session._usecase.name, + ): + response = self._session._pool.request( + "POST", + url, + body=request_body, + headers=headers, + preload_content=True, + decode_content=True, + ) + raise_for_status(response) + + raw = (response.data or b"").decode("utf-8").strip() + try: + data = json.loads(raw) + except json.JSONDecodeError: + raise ValueError("Failed to parse multipart complete response") + + if "error" in data: + raise MultipartCompleteError( + code=data["error"]["code"], + message=data["error"]["message"], + ) + + return 
data["key"] diff --git a/clients/python/tests/test_e2e.py b/clients/python/tests/test_e2e.py index 9632d489..7401fd69 100644 --- a/clients/python/tests/test_e2e.py +++ b/clients/python/tests/test_e2e.py @@ -7,6 +7,7 @@ import time from collections.abc import Generator from datetime import timedelta +from io import BytesIO from pathlib import Path import pytest @@ -14,8 +15,9 @@ import zstandard from objectstore_client import Client, Usecase from objectstore_client.auth import Permission, TokenGenerator -from objectstore_client.client import RequestError +from objectstore_client.errors import RequestError from objectstore_client.metadata import TimeToLive +from objectstore_client.multipart import CompletePart, MultipartCompleteError from objectstore_client.scope import Scope TEST_EDDSA_KID: str = "test_kid" @@ -27,6 +29,16 @@ ) +class UnrewindableStream(BytesIO): + """Read-only stream that cannot report or restore position.""" + + def seek(self, offset: int, whence: int = 0) -> int: + raise OSError("stream is not seekable") + + def tell(self) -> int: + raise OSError("stream does not expose a stable position") + + class TestTokenGenerator: _instance: TokenGenerator | None = None @@ -342,3 +354,299 @@ def test_connect_timeout() -> None: with pytest.raises(urllib3.exceptions.MaxRetryError): session.put(b"test data", compression="zstd") + + +def test_multipart_full_cycle_uncompressed(server_url: str) -> None: + client = Client(server_url, token=TestTokenGenerator.get()) + usecase = Usecase( + "test-usecase", + compression="none", + expiration_policy=TimeToLive(timedelta(days=1)), + ) + session = client.session(usecase, org=42, project=1337) + + upload = session.initiate_multipart_upload(key="mp-uncompressed") + assert upload.key == "mp-uncompressed" + assert upload.upload_id + + part1 = upload.put_part(b"hello ", part_number=1, content_length=6) + part2 = upload.put_part(b"world!", part_number=2, content_length=6) + + final_key = upload.complete([part1, part2]) + assert 
final_key == "mp-uncompressed" + + retrieved = session.get(final_key, decompress=False) + assert retrieved.payload.read() == b"hello world!" + + +def test_multipart_full_cycle_compressed(server_url: str) -> None: + client = Client(server_url, token=TestTokenGenerator.get()) + usecase = Usecase( + "test-usecase", + compression="none", + expiration_policy=TimeToLive(timedelta(days=1)), + ) + session = client.session(usecase, org=42, project=1337) + + upload = session.initiate_multipart_upload( + key="mp-compressed", + compression="zstd", + ) + + cctx = zstandard.ZstdCompressor() + compressed_part1 = cctx.compress(b"hello ") + compressed_part2 = cctx.compress(b"world!") + + part1 = upload.put_part( + compressed_part1, part_number=1, content_length=len(compressed_part1) + ) + part2 = upload.put_part( + compressed_part2, part_number=2, content_length=len(compressed_part2) + ) + + final_key = upload.complete([part1, part2]) + + # Verify raw compressed round-trip + retrieved = session.get(final_key, decompress=False) + assert retrieved.metadata.compression == "zstd" + raw = retrieved.payload.read() + assert raw == compressed_part1 + compressed_part2 + + # Verify transparent decompression + retrieved = session.get(final_key) + assert retrieved.metadata.compression is None + assert retrieved.payload.read() == b"hello world!" + + +def test_multipart_streaming_part_upload_uncompressed(server_url: str) -> None: + client = Client(server_url, token=TestTokenGenerator.get()) + usecase = Usecase( + "test-usecase", + compression="none", + expiration_policy=TimeToLive(timedelta(days=1)), + ) + session = client.session(usecase, org=42, project=1337) + + upload = session.initiate_multipart_upload(key="mp-streaming-uncompressed") + + part1_payload = b"hello " + part2_payload = b"world!" 
+ part1 = upload.put_part( + UnrewindableStream(part1_payload), + part_number=1, + content_length=len(part1_payload), + ) + part2 = upload.put_part( + UnrewindableStream(part2_payload), + part_number=2, + content_length=len(part2_payload), + ) + + final_key = upload.complete([part1, part2]) + + retrieved = session.get(final_key) + assert retrieved.payload.read() == b"hello world!" + + +def test_multipart_streaming_part_upload_compressed(server_url: str) -> None: + client = Client(server_url, token=TestTokenGenerator.get()) + usecase = Usecase( + "test-usecase", + compression="none", + expiration_policy=TimeToLive(timedelta(days=1)), + ) + session = client.session(usecase, org=42, project=1337) + + upload = session.initiate_multipart_upload( + key="mp-streaming-compressed", + compression="zstd", + ) + + cctx = zstandard.ZstdCompressor() + compressed_part1 = cctx.compress(b"hello ") + compressed_part2 = cctx.compress(b"world!") + + part1 = upload.put_part( + UnrewindableStream(compressed_part1), + part_number=1, + content_length=len(compressed_part1), + ) + part2 = upload.put_part( + UnrewindableStream(compressed_part2), + part_number=2, + content_length=len(compressed_part2), + ) + + final_key = upload.complete([part1, part2]) + + retrieved = session.get(final_key, decompress=False) + assert retrieved.metadata.compression == "zstd" + assert retrieved.payload.read() == compressed_part1 + compressed_part2 + + retrieved = session.get(final_key) + assert retrieved.metadata.compression is None + assert retrieved.payload.read() == b"hello world!" 
+ + +def test_multipart_server_generated_key(server_url: str) -> None: + client = Client(server_url, token=TestTokenGenerator.get()) + usecase = Usecase( + "test-usecase", + compression="none", + expiration_policy=TimeToLive(timedelta(days=1)), + ) + session = client.session(usecase, org=42, project=1337) + + upload = session.initiate_multipart_upload() + assert upload.key + + part = upload.put_part(b"data", part_number=1, content_length=4) + final_key = upload.complete([part]) + assert final_key + + retrieved = session.get(final_key) + assert retrieved.payload.read() == b"data" + + +def test_multipart_list_parts(server_url: str) -> None: + client = Client(server_url, token=TestTokenGenerator.get()) + usecase = Usecase( + "test-usecase", + compression="none", + expiration_policy=TimeToLive(timedelta(days=1)), + ) + session = client.session(usecase, org=42, project=1337) + + upload = session.initiate_multipart_upload(key="mp-list-parts") + + upload.put_part(b"part-two", part_number=2, content_length=8) + upload.put_part(b"part-one", part_number=1, content_length=8) + + parts = upload.list_parts() + assert len(parts) == 2 + + p1 = next(p for p in parts if p.part_number == 1) + p2 = next(p for p in parts if p.part_number == 2) + assert p1.size == 8 + assert p2.size == 8 + + upload.abort() + + +def test_multipart_abort(server_url: str) -> None: + client = Client(server_url, token=TestTokenGenerator.get()) + usecase = Usecase( + "test-usecase", + compression="none", + expiration_policy=TimeToLive(timedelta(days=1)), + ) + session = client.session(usecase, org=42, project=1337) + + upload = session.initiate_multipart_upload(key="mp-abort") + upload.put_part(b"some data", part_number=1, content_length=9) + upload.abort() + + +def test_multipart_metadata_preserved(server_url: str) -> None: + client = Client(server_url, token=TestTokenGenerator.get()) + usecase = Usecase( + "test-usecase", + compression="none", + expiration_policy=TimeToLive(timedelta(days=1)), + ) + 
session = client.session(usecase, org=42, project=1337) + + upload = session.initiate_multipart_upload( + key="mp-metadata", + content_type="text/plain", + origin="203.0.113.42", + metadata={"my-key": "my-value"}, + ) + + part = upload.put_part(b"payload", part_number=1, content_length=7) + final_key = upload.complete([part]) + + retrieved = session.get(final_key) + assert retrieved.metadata.content_type == "text/plain" + assert retrieved.metadata.origin == "203.0.113.42" + assert retrieved.metadata.custom.get("my-key") == "my-value" + + +def test_multipart_complete_with_bad_etag(server_url: str) -> None: + client = Client(server_url, token=TestTokenGenerator.get()) + usecase = Usecase( + "test-usecase", + compression="none", + expiration_policy=TimeToLive(timedelta(days=1)), + ) + session = client.session(usecase, org=42, project=1337) + + upload = session.initiate_multipart_upload(key="mp-bad-etag") + upload.put_part(b"real data", part_number=1, content_length=9) + + with pytest.raises(MultipartCompleteError) as exc_info: + upload.complete([CompletePart(part_number=1, etag="bogus-etag")]) + + assert exc_info.value.code + assert exc_info.value.status == 200 + + +def test_multipart_resume(server_url: str) -> None: + client = Client(server_url, token=TestTokenGenerator.get()) + usecase = Usecase( + "test-usecase", + compression="none", + expiration_policy=TimeToLive(timedelta(days=1)), + ) + session = client.session(usecase, org=42, project=1337) + + upload = session.initiate_multipart_upload(key="mp-resume") + saved_key = upload.key + saved_upload_id = upload.upload_id + + upload.put_part(b"first", part_number=1, content_length=5) + + # Simulate resuming from saved state + resumed = session.resume_multipart_upload(saved_key, saved_upload_id) + assert resumed.key == saved_key + assert resumed.upload_id == saved_upload_id + + resumed.put_part(b"second", part_number=2, content_length=6) + + existing = resumed.list_parts() + assert len(existing) == 2 + + final_key = 
resumed.complete(existing) + + retrieved = session.get(final_key) + assert retrieved.payload.read() == b"firstsecond" + + +def test_multipart_concurrent_part_uploads(server_url: str) -> None: + from concurrent.futures import ThreadPoolExecutor + + client = Client(server_url, token=TestTokenGenerator.get()) + usecase = Usecase( + "test-usecase", + compression="none", + expiration_policy=TimeToLive(timedelta(days=1)), + ) + session = client.session(usecase, org=42, project=1337) + + upload = session.initiate_multipart_upload(key="mp-concurrent") + + chunks = [f"chunk-{i}".encode() for i in range(8)] + + def put_part(part_number: int, data: bytes) -> CompletePart: + return upload.put_part(data, part_number=part_number, content_length=len(data)) + + with ThreadPoolExecutor(max_workers=4) as executor: + futures = [ + executor.submit(put_part, i + 1, chunk) for i, chunk in enumerate(chunks) + ] + parts = [f.result() for f in futures] + + final_key = upload.complete(parts) + + retrieved = session.get(final_key) + assert retrieved.payload.read() == b"".join(chunks) diff --git a/clients/python/tests/test_multipart.py b/clients/python/tests/test_multipart.py new file mode 100644 index 00000000..1b3966dd --- /dev/null +++ b/clients/python/tests/test_multipart.py @@ -0,0 +1,59 @@ +import json +from typing import Any + +import pytest +from objectstore_client import Client, Usecase +from objectstore_client.errors import RequestError +from objectstore_client.multipart import CompletePart, MultipartUpload + + +class FakeResponse: + def __init__( + self, + status: int, + *, + data: bytes = b"", + json_data: dict[str, Any] | None = None, + ) -> None: + self.status = status + self.data = data + self._json_data = json_data + + def read(self) -> bytes: + return self.data + + def json(self) -> dict[str, Any]: + if self._json_data is not None: + return self._json_data + return json.loads(self.data.decode("utf-8")) + + +def test_upload_part_validates_bytes_content_length() -> None: + 
client = Client("http://127.0.0.1:8888") + session = client.session(Usecase("testing"), org=1) + upload = MultipartUpload(session, "key", "upload-id") + + with pytest.raises(ValueError, match="content_length must match"): + upload.put_part(b"payload", part_number=1, content_length=1) + + +def test_multipart_complete_raises_http_errors_before_parsing( + monkeypatch: pytest.MonkeyPatch, +) -> None: + client = Client("http://127.0.0.1:8888") + session = client.session(Usecase("testing"), org=1) + upload = MultipartUpload(session, "key", "upload-id") + + monkeypatch.setattr( + session._pool, + "request", + lambda *args, **kwargs: FakeResponse( + 403, data=b'{"detail":"missing or expired auth"}' + ), + ) + + with pytest.raises(RequestError) as exc_info: + upload.complete([CompletePart(part_number=1, etag="etag")]) + + assert exc_info.value.status == 403 + assert exc_info.value.response == '{"detail":"missing or expired auth"}'