diff --git a/Cargo.lock b/Cargo.lock index 67445a39..0ca8c54b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2566,6 +2566,7 @@ name = "objectstore-client" version = "0.1.9" dependencies = [ "async-compression", + "base64", "bytes", "futures-util", "infer", @@ -2702,6 +2703,7 @@ version = "0.1.9" dependencies = [ "http 1.4.0", "humantime", + "humantime-serde", "insta", "mediatype", "serde", diff --git a/clients/rust/Cargo.toml b/clients/rust/Cargo.toml index 79bdee5a..6d3031d3 100644 --- a/clients/rust/Cargo.toml +++ b/clients/rust/Cargo.toml @@ -12,6 +12,7 @@ publish = true [dependencies] async-compression = { version = "0.4.27", features = ["tokio", "zstd"] } +base64 = { version = "0.22.1", optional = true } percent-encoding = { workspace = true } bytes = { workspace = true } futures-util = { workspace = true } @@ -39,6 +40,13 @@ zstd = "0.13.3" [features] default = ["native-tls", "hickory-dns"] +multipart = ["dep:base64"] + rustls = ["reqwest/rustls"] native-tls = ["reqwest/native-tls"] hickory-dns = ["reqwest/hickory-dns"] + +[[test]] +name = "multipart" +path = "tests/multipart.rs" +required-features = ["multipart"] diff --git a/clients/rust/README.md b/clients/rust/README.md index cf8dde88..67ff98fa 100644 --- a/clients/rust/README.md +++ b/clients/rust/README.md @@ -128,6 +128,89 @@ session.put("payload") .send().await?; ``` +### Multipart Upload API + +For large objects, use multipart uploads to upload parts concurrently with bounded +parallelism. + +**Important:** unlike single-object uploads, multipart uploads do **not** auto-compress. +The caller must pre-compress each part according to the compression set as part of the metadata +when initiating the upload. + +```rust,ignore +use futures_util::StreamExt as _; +use futures_util::stream; +use objectstore_client::Compression; + +let upload = session + .initiate_multipart_upload() + .key("my-large-object") + .compression(Compression::Zstd) + .send() + .await?; + +let parts: Vec<(Vec, u32)> = vec![ + (zstd::encode_all(&part1_data[..], 0)?, 1), + (zstd::encode_all(&part2_data[..], 0)?, 2), +]; + +let results: Vec<_> = stream::iter( + parts + .into_iter() + .map(|(data, part_number)| upload.put(data, part_number, None)), +) +.buffer_unordered(8) +.collect() +.await; + +let mut done = Vec::new(); +let mut errors = Vec::new(); +for result in results { + match result { + Ok(part) => done.push(part), + Err(e) => errors.push(e), + } +} + +if !errors.is_empty() { + // reupload failed parts... +} + +let key = upload.complete(done).await?; +// or +upload.abort().await?; +``` + +You can also resume an in-progress multipart upload, e.g. after a process restart. + +```rust,ignore +use futures_util::{StreamExt as _, TryStreamExt as _}; +use futures_util::stream; +use objectstore_client::CompletePart; + +let upload = session.resume_multipart_upload("my-large-object", saved_upload_id)?; + +let existing = upload.list_parts().await?; +let total_parts = 10; +let uploaded: Vec = existing.iter().map(|p| p.part_number.get()).collect(); +let missing: Vec = (1..=total_parts) + .filter(|n| !uploaded.contains(n)) + .collect(); + +let mut done: Vec<_> = stream::iter( + missing + .into_iter() + .map(|part_number| upload.put(get_part_data(part_number), part_number, None)), +) +.buffer_unordered(8) +.try_collect() +.await?; + +done.extend(existing.into_iter().map(CompletePart::from)); + +let key = upload.complete(done).await?; +``` + ### Many API The Many API allows you to enqueue multiple requests that the client can execute using Objectstore's batch endpoint, minimizing network overhead. diff --git a/clients/rust/src/client.rs b/clients/rust/src/client.rs index 2d4eced4..7a8848f9 100644 --- a/clients/rust/src/client.rs +++ b/clients/rust/src/client.rs @@ -151,7 +151,7 @@ impl ClientBuilder { #[derive(Debug, Clone)] pub struct Usecase { name: Arc, - compression: Compression, + compression: Option, expiration_policy: ExpirationPolicy, } @@ -160,7 +160,7 @@ impl Usecase { pub fn new(name: &str) -> Self { Self { name: name.into(), - compression: Compression::Zstd, + compression: Some(Compression::Zstd), expiration_policy: Default::default(), } } @@ -173,7 +173,7 @@ impl Usecase { /// Returns the compression algorithm to use for operations within this usecase. #[inline] - pub fn compression(&self) -> Compression { + pub fn compression(&self) -> Option { self.compression } @@ -181,10 +181,10 @@ impl Usecase { /// /// It's still possible to override this default on each operation's builder. /// - /// By default, [`Compression::Zstd`] is used. - pub fn with_compression(self, compression: Compression) -> Self { + /// By default, [`Compression::Zstd`] is used. Pass [`None`] to disable compression. + pub fn with_compression(self, compression: impl Into>) -> Self { Self { - compression, + compression: compression.into(), ..self } } @@ -464,6 +464,41 @@ impl Session { url } + #[cfg(feature = "multipart")] + fn multipart_url( + &self, + suffix: Option<&'static str>, + object_key: Option<&str>, + query_pairs: Option>, + ) -> Url { + let mut url = self.client.service_url.clone(); + + // `path_segments_mut` can only error if the url is cannot-be-a-base, + // and we check that in `ClientBuilder::new`, therefore this will never panic. + let mut segments = url.path_segments_mut().unwrap(); + segments + .push("v1") + .push(match suffix { + Some("parts") => "objects:multipart:parts", + Some("complete") => "objects:multipart:complete", + _ => "objects:multipart", + }) + .push(&self.scope.usecase.name) + .push(&self.scope.scopes.as_api_path().to_string()); + if let Some(object_key) = object_key.filter(|key| !key.is_empty()) { + segments.extend(object_key.split("/")); + } + drop(segments); + if let Some(query_pairs) = query_pairs { + let mut pairs = url.query_pairs_mut(); + for (key, value) in query_pairs { + pairs.append_pair(key, &value); + } + } + + url + } + fn prepare_builder(&self, mut builder: RequestBuilder) -> crate::Result { if let Some(token) = self.mint_token()? { builder = builder.header("x-os-auth", format!("Bearer {token}")); @@ -493,6 +528,19 @@ impl Session { let builder = self.client.reqwest.post(url); self.prepare_builder(builder) } + + #[cfg(feature = "multipart")] + pub(crate) fn multipart_request( + &self, + method: reqwest::Method, + action: Option<&'static str>, + object_key: Option<&str>, + query_pairs: Option>, + ) -> crate::Result { + let url = self.multipart_url(action, object_key, query_pairs); + let builder = self.client.reqwest.request(method, url); + self.prepare_builder(builder) + } } #[cfg(test)] diff --git a/clients/rust/src/error.rs b/clients/rust/src/error.rs index 098bb79a..341ea47d 100644 --- a/clients/rust/src/error.rs +++ b/clients/rust/src/error.rs @@ -44,6 +44,20 @@ pub enum Error { /// The error message. message: String, }, + /// Error when part number validation fails (must be >= 1). + #[error("invalid part number: {0}")] + InvalidPartNumber(u32), + /// Error when upload ID validation fails. + #[error(transparent)] + InvalidUploadId(#[from] objectstore_types::multipart::InvalidUploadId), + /// Error returned when attempting to complete a multipart upload. + #[error("multipart complete failed ({code}): {message}")] + MultipartComplete { + /// The error code or kind. + code: String, + /// The error message. + message: String, + }, } /// A convenience alias that defaults our [`Error`] type. diff --git a/clients/rust/src/get.rs b/clients/rust/src/get.rs index cdddc096..5c4cf446 100644 --- a/clients/rust/src/get.rs +++ b/clients/rust/src/get.rs @@ -128,7 +128,12 @@ pub(crate) fn maybe_decompress( match (metadata.compression, decompress && !encoding_accepted) { (Some(Compression::Zstd), true) => { metadata.compression = None; - ReaderStream::new(ZstdDecoder::new(StreamReader::new(stream))).boxed() + let mut decoder = ZstdDecoder::new(StreamReader::new(stream)); + // Multipart uploads with compression, when each part is compressed individually, + // will consist of multiple concatenated zstd frames. + // This allows the client to handle automatic decompression for these objects transparently. + decoder.multiple_members(true); + ReaderStream::new(decoder).boxed() } _ => stream, } @@ -230,4 +235,22 @@ mod tests { assert_eq!(collect(out).await, payload); assert_eq!(metadata.compression, None); } + + #[tokio::test] + async fn zstd_concatenated_frames_decompress() { + let payload1 = b"hello "; + let payload2 = b"world"; + let compressed1 = collect(compressed_zstd_stream(payload1)).await; + let compressed2 = collect(compressed_zstd_stream(payload2)).await; + let stream = futures_util::stream::iter([ + Ok::<_, std::io::Error>(bytes::Bytes::from(compressed1)), + Ok::<_, std::io::Error>(bytes::Bytes::from(compressed2)), + ]) + .boxed(); + + let mut metadata = zstd_metadata(); + let out = maybe_decompress(stream, &mut metadata, true, &[]); + assert_eq!(collect(out).await, b"hello world"); + assert_eq!(metadata.compression, None); + } } diff --git a/clients/rust/src/lib.rs b/clients/rust/src/lib.rs index 37758eff..3d497f18 100644 --- a/clients/rust/src/lib.rs +++ b/clients/rust/src/lib.rs @@ -10,6 +10,8 @@ mod get; mod head; mod key; mod many; +#[cfg(feature = "multipart")] +mod multipart; mod put; pub mod utils; @@ -23,4 +25,6 @@ pub use get::*; pub use head::*; pub use key::*; pub use many::*; +#[cfg(feature = "multipart")] +pub use multipart::*; pub use put::*; diff --git a/clients/rust/src/multipart.rs b/clients/rust/src/multipart.rs new file mode 100644 index 00000000..e5c6c690 --- /dev/null +++ b/clients/rust/src/multipart.rs @@ -0,0 +1,388 @@ +use std::borrow::Cow; +use std::collections::BTreeMap; + +use base64::Engine as _; +use bytes::Bytes; +use futures_util::StreamExt as _; +use objectstore_types::metadata::Metadata; +use objectstore_types::multipart::{ + CompleteErrorDetail, CompleteRequest, CompleteSuccessResponse, InitiateResponse, + ListPartsResponse, UploadPartResponse, +}; +use reqwest::Body; +use serde::Deserialize; +use tokio::io::AsyncRead; +use tokio_util::io::ReaderStream; + +use crate::{ClientStream, ObjectKey, Session}; + +pub use objectstore_types::multipart::CompletePart; +pub use objectstore_types::multipart::ETag; +pub use objectstore_types::multipart::PartInfo; +pub use objectstore_types::multipart::PartNumber; +pub use objectstore_types::multipart::UploadId; + +#[derive(Deserialize)] +#[serde(untagged)] +enum CompleteResponse { + Error { error: CompleteErrorDetail }, + Success(CompleteSuccessResponse), +} + +impl Session { + /// Creates a builder for initiating a multipart upload. + /// + /// The returned [`InitiateMultipartBuilder`] inherits the session's default compression + /// and expiration settings. + /// + /// IMPORTANT: unlike single-object uploads, the client does not automatically compress the + /// contents of [`MultipartUpload::put`]/[`MultipartUpload::put_stream`] based on the + /// configured `compression`. + /// The caller is responsible to compress the payload in accordance with the configured + /// `compression`. + /// That's because we require `content_length` on each part to be the length of the compressed + /// content, which we wouldn't be able to know beforehand if `objectstore_client` automatically + /// compressed payloads on the fly. + pub fn initiate_multipart_upload(&self) -> InitiateMultipartBuilder { + let metadata = Metadata { + expiration_policy: self.scope.usecase().expiration_policy(), + compression: self.scope.usecase().compression(), + ..Default::default() + }; + + InitiateMultipartBuilder { + session: self.clone(), + metadata, + key: None, + } + } + + /// Resumes an existing multipart upload from its key and upload ID. + /// + /// This reconstructs a [`MultipartUpload`] handle from previously obtained identifiers, and + /// doesn't make any network calls. + /// Use this to resume an upload after a process restart or to continue an upload initiated elsewhere. + pub fn resume_multipart_upload( + &self, + key: impl Into, + upload_id: impl Into, + ) -> crate::Result { + Ok(MultipartUpload { + session: self.clone(), + key: key.into(), + upload_id: UploadId::new(upload_id.into())?, + }) + } +} + +/// A builder for initiating a multipart upload. +#[derive(Debug)] +pub struct InitiateMultipartBuilder { + session: Session, + metadata: Metadata, + key: Option, +} + +impl InitiateMultipartBuilder { + /// Sets an explicit object key. + /// + /// If a key is specified, the object will be stored under that key. Otherwise, the Objectstore + /// server will automatically assign a random key, which is then returned from this request. + pub fn key(mut self, key: impl Into) -> Self { + self.key = Some(key.into()).filter(|k| !k.is_empty()); + self + } + + /// Sets the compression algorithm recorded in this object's metadata. + /// + /// IMPORTANT: unlike single-object uploads, the client does not automatically compress the + /// contents of [`MultipartUpload::put`]/[`MultipartUpload::put_stream`] based on the + /// configured `compression`. + /// The caller is responsible to compress the payload in accordance with the configured + /// `compression`. + /// + /// By default, the compression algorithm set on this Session's Usecase is used. + pub fn compression(mut self, compression: impl Into>) -> Self { + self.metadata.compression = compression.into(); + self + } + + /// Sets the expiration policy of the object to be uploaded. + /// + /// By default, the expiration policy set on this Session's Usecase is used. + pub fn expiration_policy(mut self, expiration_policy: crate::ExpirationPolicy) -> Self { + self.metadata.expiration_policy = expiration_policy; + self + } + + /// Sets the content type of the object to be uploaded. + /// + /// You can use the utility function [`crate::utils::guess_mime_type`] to attempt to guess a + /// `content_type` based on magic bytes. + pub fn content_type(mut self, content_type: impl Into>) -> Self { + self.metadata.content_type = content_type.into(); + self + } + + /// Sets the origin of the object, typically the IP address of the original source. + /// + /// This is an optional but encouraged field that tracks where the payload was + /// originally obtained from. For example, the IP address of the Sentry SDK or CLI + /// that uploaded the data. + /// + /// # Example + /// + /// ```no_run + /// # async fn example(session: objectstore_client::Session) { + /// session.initiate_multipart_upload() + /// .origin("203.0.113.42") + /// .send() + /// .await + /// .unwrap(); + /// # } + /// ``` + pub fn origin(mut self, origin: impl Into) -> Self { + self.metadata.origin = Some(origin.into()); + self + } + + /// Sets the custom metadata to the provided map. + /// + /// It will clear any previously set metadata. + pub fn set_metadata(mut self, metadata: impl Into>) -> Self { + self.metadata.custom = metadata.into(); + self + } + + /// Appends the `key`/`value` to the custom metadata of this object. + pub fn append_metadata(mut self, key: impl Into, value: impl Into) -> Self { + self.metadata.custom.insert(key.into(), value.into()); + self + } + + /// Sends the initiate request and returns a [`MultipartUpload`] handle. + pub async fn send(self) -> crate::Result { + let method = match self.key { + Some(_) => reqwest::Method::PUT, + None => reqwest::Method::POST, + }; + + let mut builder = + self.session + .multipart_request(method, None, self.key.as_deref(), None)?; + + builder = builder.headers(self.metadata.to_headers("")?); + + let response: InitiateResponse = builder.send().await?.error_for_status()?.json().await?; + + Ok(MultipartUpload { + session: self.session, + key: response.key, + upload_id: response.upload_id, + }) + } +} + +/// Represents an ongoing Multipart Upload, tied to a specific [`Session`] and [`UploadId`]. +/// +/// Create a Multipart Upload handle using [`Session::initiate_multipart_upload`] or [`Session::resume_multipart_upload`]. +#[derive(Debug)] +pub struct MultipartUpload { + session: Session, + key: String, + upload_id: UploadId, +} + +impl MultipartUpload { + /// Returns the upload session identifier. + pub fn upload_id(&self) -> &UploadId { + &self.upload_id + } + + /// Returns the key of the object that this upload will create. + pub fn key(&self) -> &ObjectKey { + &self.key + } + + /// Uploads a part using a [`Bytes`]-like payload. + /// + /// IMPORTANT: unlike single-object uploads, the client does not automatically compress + /// contents based on this upload's `Metadata::compression`. + /// The caller is responsible to compress the payload in accordance with the `compression`, + /// and, optionally, to pass the `content_md5` of the compressed payload. + pub async fn put( + &self, + body: impl Into, + part_number: u32, + content_md5: Option<&[u8; 16]>, + ) -> crate::Result { + let bytes = body.into(); + let content_length = bytes.len() as u64; + self.upload_part(bytes.into(), part_number, content_length, content_md5) + .await + } + + /// Uploads a part using a streaming payload. + /// + /// IMPORTANT: unlike single-object uploads, the client does not automatically compress + /// contents based on this upload's `Metadata::compression`. + /// The caller is responsible to compress the payload in accordance with the `compression`, + /// and to pass the `content_length` and, optionally, `content_md5` of the compressed payload. + pub async fn put_stream( + &self, + stream: ClientStream, + part_number: u32, + content_length: u64, + content_md5: Option<&[u8; 16]>, + ) -> crate::Result { + self.upload_part( + Body::wrap_stream(stream), + part_number, + content_length, + content_md5, + ) + .await + } + + /// Uploads a part from an [`AsyncRead`] source. + /// + /// IMPORTANT: unlike single-object uploads, the client does not automatically compress + /// contents based on this upload's `Metadata::compression`. + /// The caller is responsible to compress the payload in accordance with the `compression`, + /// and to pass the `content_length` and, optionally, `content_md5` of the compressed payload. + pub async fn put_read( + &self, + reader: R, + part_number: u32, + content_length: u64, + content_md5: Option<&[u8; 16]>, + ) -> crate::Result + where + R: AsyncRead + Send + Sync + 'static, + { + let stream = ReaderStream::new(reader).boxed(); + self.put_stream(stream, part_number, content_length, content_md5) + .await + } + + async fn upload_part( + &self, + body: Body, + part_number: u32, + content_length: u64, + content_md5: Option<&[u8; 16]>, + ) -> crate::Result { + let part_number = + PartNumber::new(part_number).ok_or(crate::Error::InvalidPartNumber(part_number))?; + + let mut builder = self + .session + .multipart_request( + reqwest::Method::PUT, + Some("parts"), + Some(&self.key), + Some(vec![ + ("upload_id", self.upload_id.to_string()), + ("part_number", part_number.to_string()), + ]), + )? + .header(reqwest::header::CONTENT_LENGTH, content_length) + .body(body); + + if let Some(md5) = content_md5 { + let encoded = base64::engine::general_purpose::STANDARD.encode(md5); + builder = builder.header("content-md5", encoded); + } + + let response: UploadPartResponse = builder.send().await?.error_for_status()?.json().await?; + Ok(CompletePart { + part_number, + etag: response.etag, + }) + } + + /// Lists all parts that have been uploaded for this multipart upload. + pub async fn list_parts(&self) -> crate::Result> { + let mut all_parts = Vec::new(); + let mut marker = None; + + loop { + let page = self.list_parts_page(None, marker).await?; + all_parts.extend(page.parts); + + if !page.is_truncated { + return Ok(all_parts); + } + marker = page.next_part_number_marker; + if marker.is_none() { + return Err(crate::Error::MalformedResponse( + "server returned is_truncated=true but no next_part_number_marker. Please report a bug.".into(), + )); + } + } + } + + async fn list_parts_page( + &self, + max_parts: Option, + part_number_marker: Option, + ) -> crate::Result { + let mut params: Vec<(&str, String)> = vec![("upload_id", self.upload_id.to_string())]; + if let Some(max) = max_parts { + params.push(("max_parts", max.to_string())); + } + if let Some(marker) = part_number_marker { + params.push(("part_number_marker", marker.to_string())); + } + + let builder = self.session.multipart_request( + reqwest::Method::GET, + Some("parts"), + Some(&self.key), + Some(params), + )?; + + let response: ListPartsResponse = builder.send().await?.error_for_status()?.json().await?; + Ok(response) + } + + /// Aborts this multipart upload. + pub async fn abort(self) -> crate::Result<()> { + let builder = self.session.multipart_request( + reqwest::Method::DELETE, + None, + Some(&self.key), + Some(vec![("upload_id", self.upload_id.to_string())]), + )?; + builder.send().await?.error_for_status()?; + Ok(()) + } + + /// Completes the multipart upload, assembling all parts into the final object. + pub async fn complete( + self, + parts: impl IntoIterator, + ) -> crate::Result { + let mut parts: Vec<_> = parts.into_iter().collect(); + parts.sort_by_key(|p| p.part_number); + + let builder = self + .session + .multipart_request( + reqwest::Method::POST, + Some("complete"), + Some(&self.key), + Some(vec![("upload_id", self.upload_id.to_string())]), + )? + .json(&CompleteRequest { parts }); + + let response = builder.send().await?.error_for_status()?; + match response.json::().await? { + CompleteResponse::Success(s) => Ok(s.key), + CompleteResponse::Error { error } => Err(crate::Error::MultipartComplete { + code: error.code, + message: error.message, + }), + } + } +} diff --git a/clients/rust/src/put.rs b/clients/rust/src/put.rs index a1abc50e..465039af 100644 --- a/clients/rust/src/put.rs +++ b/clients/rust/src/put.rs @@ -41,7 +41,7 @@ impl Session { fn put_body(&self, body: PutBody) -> PutBuilder { let metadata = Metadata { expiration_policy: self.scope.usecase().expiration_policy(), - compression: Some(self.scope.usecase().compression()), + compression: self.scope.usecase().compression(), ..Default::default() }; diff --git a/clients/rust/tests/common/mod.rs b/clients/rust/tests/common/mod.rs new file mode 100644 index 00000000..c4e4513d --- /dev/null +++ b/clients/rust/tests/common/mod.rs @@ -0,0 +1,37 @@ +#![allow(dead_code)] + +use std::sync::LazyLock; + +use objectstore_client::{Client, SecretKey, Session, TokenGenerator, Usecase}; +use objectstore_test::server::{TEST_EDDSA_KID, TEST_EDDSA_PRIVKEY_PATH, TestServer, config}; + +pub static TEST_EDDSA_PRIVKEY: LazyLock = + LazyLock::new(|| std::fs::read_to_string(&*TEST_EDDSA_PRIVKEY_PATH).unwrap()); + +pub async fn test_server() -> TestServer { + TestServer::with_config(config::Config { + auth: config::AuthZ { + enforce: true, + ..Default::default() + }, + ..Default::default() + }) + .await +} + +pub fn test_token_generator() -> TokenGenerator { + TokenGenerator::new(SecretKey { + kid: TEST_EDDSA_KID.into(), + secret_key: TEST_EDDSA_PRIVKEY.clone(), + }) + .unwrap() +} + +pub fn test_session(server: &TestServer) -> Session { + let client = Client::builder(server.url("/")) + .token(test_token_generator()) + .build() + .unwrap(); + let usecase = Usecase::new("usecase"); + client.session(usecase.for_organization(12345)).unwrap() +} diff --git a/clients/rust/tests/e2e.rs b/clients/rust/tests/e2e.rs index 88a90d80..c3db3a03 100644 --- a/clients/rust/tests/e2e.rs +++ b/clients/rust/tests/e2e.rs @@ -1,20 +1,17 @@ +mod common; + use std::collections::{BTreeMap, HashSet}; use std::io::Write as _; -use std::sync::LazyLock; +use common::{TEST_EDDSA_PRIVKEY, test_server, test_token_generator}; use futures_util::StreamExt as _; use jsonwebtoken::{Algorithm, EncodingKey, Header, encode, get_current_timestamp}; -use objectstore_client::{ - Client, Error, OperationResult, Permission, SecretKey, TokenGenerator, Usecase, -}; -use objectstore_test::server::{TEST_EDDSA_KID, TEST_EDDSA_PRIVKEY_PATH, TestServer, config}; +use objectstore_client::{Client, Error, OperationResult, Permission, Usecase}; +use objectstore_test::server::TEST_EDDSA_KID; use objectstore_types::metadata::Compression; use reqwest::StatusCode; use serde::Serialize; -pub static TEST_EDDSA_PRIVKEY: LazyLock = - LazyLock::new(|| std::fs::read_to_string(&*TEST_EDDSA_PRIVKEY_PATH).unwrap()); - #[derive(Serialize)] struct JwtClaims { exp: u64, @@ -53,25 +50,6 @@ fn sign_static_token(usecase: &str, scopes: &[(&str, &str)]) -> String { encode(&header, &claims, &encoding_key).unwrap() } -async fn test_server() -> TestServer { - TestServer::with_config(config::Config { - auth: config::AuthZ { - enforce: true, - ..Default::default() - }, - ..Default::default() - }) - .await -} - -fn test_token_generator() -> TokenGenerator { - TokenGenerator::new(SecretKey { - kid: TEST_EDDSA_KID.into(), - secret_key: TEST_EDDSA_PRIVKEY.clone(), - }) - .unwrap() -} - #[tokio::test] async fn stores_uncompressed() { let server = test_server().await; diff --git a/clients/rust/tests/multipart.rs b/clients/rust/tests/multipart.rs new file mode 100644 index 00000000..08610656 --- /dev/null +++ b/clients/rust/tests/multipart.rs @@ -0,0 +1,264 @@ +//! End-to-end tests for the multipart upload client API. + +mod common; + +use common::{test_server, test_session}; +use futures_util::StreamExt as _; +use futures_util::stream; +use objectstore_client::{Client, CompletePart, Compression, Error, PartNumber, Usecase}; + +use crate::common::test_token_generator; + +#[tokio::test] +async fn test_full_upload_uncompressed() { + let server = test_server().await; + let client = Client::builder(server.url("/")) + .token(test_token_generator()) + .build() + .unwrap(); + let usecase = Usecase::new("usecase").with_compression(None); + + let session = client.session(usecase.for_organization(12345)).unwrap(); + + let upload = session + .initiate_multipart_upload() + .key("multipart-test-key") + .send() + .await + .unwrap(); + + assert_eq!(upload.key(), "multipart-test-key"); + assert!(!upload.upload_id().is_empty()); + + let parts_data: Vec<(&[u8], u32)> = vec![(b"hello ", 1), (b"world!", 2)]; + + let results: Vec<_> = stream::iter( + parts_data + .into_iter() + .map(|(data, part_number)| upload.put(data, part_number, None)), + ) + .buffer_unordered(2) + .collect() + .await; + + let mut parts = Vec::new(); + let mut errors = Vec::new(); + for result in results { + match result { + Ok(part) => parts.push(part), + Err(e) => errors.push(e), + } + } + assert!(errors.is_empty(), "part uploads failed: {errors:?}"); + + let key = upload.complete(parts).await.unwrap(); + + assert_eq!(key, "multipart-test-key"); + + let response = session + .get(&key) + .decompress(false) + .send() + .await + .unwrap() + .unwrap(); + assert_eq!(response.metadata.compression, None); + let payload = response.payload().await.unwrap(); + assert_eq!(payload, "hello world!"); +} + +#[tokio::test] +async fn test_full_upload_compressed() { + let server = test_server().await; + let session = test_session(&server); + + let upload = session + .initiate_multipart_upload() + .key("multipart-compressed-key") + .compression(Compression::Zstd) + .send() + .await + .unwrap(); + + let part1_data = b"hello "; + let part2_data = b"world!"; + + let parts_data: Vec<(Vec, u32)> = vec![ + (zstd::encode_all(&part1_data[..], 0).unwrap(), 1), + (zstd::encode_all(&part2_data[..], 0).unwrap(), 2), + ]; + + let results: Vec<_> = stream::iter( + parts_data + .into_iter() + .map(|(data, part_number)| upload.put(data, part_number, None)), + ) + .buffer_unordered(2) + .collect() + .await; + + let mut parts = Vec::new(); + let mut errors = Vec::new(); + for result in results { + match result { + Ok(part) => parts.push(part), + Err(e) => errors.push(e), + } + } + assert!(errors.is_empty(), "part uploads failed: {errors:?}"); + + let key = upload.complete(parts).await.unwrap(); + + let response = session + .get(&key) + .decompress(false) + .send() + .await + .unwrap() + .unwrap(); + assert_eq!(response.metadata.compression, Some(Compression::Zstd)); + + let mut expected = zstd::encode_all(&part1_data[..], 0).unwrap(); + expected.extend(zstd::encode_all(&part2_data[..], 0).unwrap()); + assert_eq!( + response.payload().await.unwrap().as_ref(), + expected.as_slice() + ); + + let response = session.get(&key).send().await.unwrap().unwrap(); + assert_eq!(response.metadata.compression, None); + assert_eq!(response.payload().await.unwrap(), "hello world!"); +} + +#[tokio::test] +async fn test_server_generated_key() { + let server = test_server().await; + let session = test_session(&server); + + let upload = session + .initiate_multipart_upload() + .compression(None) + .send() + .await + .unwrap(); + + assert!(!upload.key().is_empty()); + + let part = upload.put(b"data".as_slice(), 1, None).await.unwrap(); + + let key = upload.complete([part]).await.unwrap(); + + assert!(!key.is_empty()); + + let response = session.get(&key).send().await.unwrap().unwrap(); + assert_eq!(response.payload().await.unwrap(), "data"); +} + +#[tokio::test] +async fn test_list_parts() { + let server = test_server().await; + let session = test_session(&server); + + let upload = session + .initiate_multipart_upload() + .key("list-parts-key") + .compression(None) + .send() + .await + .unwrap(); + + upload.put(b"part-two".as_slice(), 2, None).await.unwrap(); + upload.put(b"part-one".as_slice(), 1, None).await.unwrap(); + + let parts = upload.list_parts().await.unwrap(); + assert_eq!(parts.len(), 2); + + let p1 = parts + .iter() + .find(|p| p.part_number.get() == 1) + .expect("missing part 1"); + let p2 = parts + .iter() + .find(|p| p.part_number.get() == 2) + .expect("missing part 2"); + assert_eq!(p1.size, 8); + assert_eq!(p2.size, 8); + + upload.abort().await.unwrap(); +} + +#[tokio::test] +async fn test_abort() { + let server = test_server().await; + let session = test_session(&server); + + let upload = session + .initiate_multipart_upload() + .key("abort-key") + .send() + .await + .unwrap(); + + upload.put(b"some data".as_slice(), 1, None).await.unwrap(); + upload.abort().await.unwrap(); +} + +#[tokio::test] +async fn test_metadata_preserved() { + let server = test_server().await; + let session = test_session(&server); + + let upload = session + .initiate_multipart_upload() + .key("metadata-key") + .compression(None) + .content_type("text/plain") + .origin("203.0.113.42") + .append_metadata("my-key".to_string(), "my-value".to_string()) + .send() + .await + .unwrap(); + + let part = upload.put(b"payload".as_slice(), 1, None).await.unwrap(); + + let key = upload.complete([part]).await.unwrap(); + + let response = session.get(&key).send().await.unwrap().unwrap(); + assert_eq!(response.metadata.content_type, "text/plain"); + assert_eq!(response.metadata.origin.as_deref(), Some("203.0.113.42")); + assert_eq!( + response.metadata.custom.get("my-key").map(String::as_str), + Some("my-value") + ); +} + +#[tokio::test] +async fn test_complete_with_bad_etag() { + let server = test_server().await; + let session = test_session(&server); + + let upload = session + .initiate_multipart_upload() + .key("bad-etag-key") + .compression(None) + .send() + .await + .unwrap(); + + upload.put(b"real data".as_slice(), 1, None).await.unwrap(); + + let result = upload + .complete(vec![CompletePart { + part_number: PartNumber::new(1).unwrap(), + etag: "bogus-etag".to_string(), + }]) + .await; + + match result { + Err(Error::MultipartComplete { code, message }) => { + assert!(!code.is_empty(), "error code should not be empty"); + assert!(!message.is_empty(), "error message should not be empty"); + } + other => panic!("expected MultipartComplete error, got: {other:?}"), + } +} diff --git a/objectstore-server/src/endpoints/common.rs b/objectstore-server/src/endpoints/common.rs index 365f8e62..ab6a8240 100644 --- a/objectstore-server/src/endpoints/common.rs +++ b/objectstore-server/src/endpoints/common.rs @@ -93,6 +93,7 @@ impl ApiError { ApiError::Service(ServiceError::Client(_)) => StatusCode::BAD_REQUEST, ApiError::Service(ServiceError::Metadata(_)) => StatusCode::BAD_REQUEST, + ApiError::Service(ServiceError::InvalidUploadId(_)) => StatusCode::BAD_REQUEST, ApiError::Service(ServiceError::AtCapacity) => StatusCode::TOO_MANY_REQUESTS, ApiError::Service(ServiceError::NotImplemented) => StatusCode::NOT_IMPLEMENTED, ApiError::Service(_) => { diff --git a/objectstore-server/src/endpoints/multipart.rs b/objectstore-server/src/endpoints/multipart.rs index a9a5f17b..c9d40d26 100644 --- a/objectstore-server/src/endpoints/multipart.rs +++ b/objectstore-server/src/endpoints/multipart.rs @@ -1,7 +1,11 @@ -use std::collections::BTreeMap; use std::convert::Infallible; use std::time::{Duration, SystemTime}; +use crate::auth::AuthAwareService; +use crate::endpoints::common::{ApiError, ApiResult}; +use crate::extractors::Xt; +use crate::extractors::body::MeteredBody; +use crate::state::ServiceState; use axum::body::Body; use axum::extract::{Query, State}; use axum::http::{HeaderMap, StatusCode}; @@ -16,13 +20,11 @@ use objectstore_service::error::Error as ServiceError; use objectstore_service::id::{ObjectContext, ObjectId}; use objectstore_service::multipart::{CompletedPart, PartNumber, UploadId}; use objectstore_types::metadata::Metadata; -use serde::{Deserialize, Serialize}; - -use crate::auth::AuthAwareService; -use crate::endpoints::common::{ApiError, ApiResult}; -use crate::extractors::Xt; -use crate::extractors::body::MeteredBody; -use crate::state::ServiceState; +use objectstore_types::multipart::{ + CompleteErrorDetail, CompleteErrorResponse, CompleteRequest, CompleteSuccessResponse, + InitiateResponse, ListPartsResponse, PartInfo, UploadPartResponse, +}; +use serde::Deserialize; pub fn router() -> Router { let initiate_no_key = routing::post(initiate_post); @@ -66,62 +68,6 @@ struct ListPartsQuery { part_number_marker: Option, } -// --- Request/Response types --- - -#[derive(Debug, Serialize)] -struct InitiateResponse { - key: String, - upload_id: UploadId, -} - -#[derive(Debug, Serialize)] -struct UploadPartResponse { - etag: String, -} - -#[derive(Debug, Serialize)] -struct PartInfo { - etag: String, - #[serde(with = "humantime_serde")] - last_modified: SystemTime, - size: u64, -} - -#[derive(Debug, Serialize)] -struct ListPartsResponse { - parts: BTreeMap, - is_truncated: bool, - #[serde(skip_serializing_if = "Option::is_none")] - next_part_number_marker: Option, -} - -#[derive(Debug, Deserialize)] -struct CompletePartRequest { - part_number: PartNumber, - etag: String, -} - -#[derive(Debug, Deserialize)] -struct CompleteRequest { - parts: Vec, -} - -#[derive(Debug, Serialize)] -struct CompleteSuccessResponse { - key: String, -} - -#[derive(Debug, Serialize)] -struct CompleteErrorDetail { - code: String, - message: String, -} - -#[derive(Debug, Serialize)] -struct CompleteErrorResponse { - error: CompleteErrorDetail, -} - // --- Handlers --- async fn initiate_put( @@ -219,13 +165,11 @@ async fn list_parts( let parts = response .parts .into_iter() - .map(|p| { - let info = PartInfo { - etag: p.etag, - last_modified: p.last_modified, - size: p.size, - }; - (p.part_number, info) + .map(|p| PartInfo { + part_number: p.part_number, + etag: p.etag, + last_modified: p.last_modified, + size: p.size, }) .collect(); diff --git a/objectstore-server/tests/multipart.rs b/objectstore-server/tests/multipart.rs index eaa70b24..0dbd28cd 100644 --- a/objectstore-server/tests/multipart.rs +++ b/objectstore-server/tests/multipart.rs @@ -3,48 +3,13 @@ use anyhow::Result; use objectstore_server::config::{AuthZ, Config}; use objectstore_test::server::TestServer; -use serde::Deserialize; +use objectstore_types::multipart::{ + CompleteErrorResponse, CompleteSuccessResponse, InitiateResponse, ListPartsResponse, + PartNumber, UploadPartResponse, +}; -#[derive(Debug, Deserialize)] -struct InitiateResponse { - key: String, - upload_id: String, -} - -#[derive(Debug, Deserialize)] -struct UploadPartResponse { - etag: String, -} - -#[derive(Debug, Deserialize)] -#[allow(dead_code)] -struct PartInfo { - etag: String, - last_modified: String, - size: u64, -} - -#[derive(Debug, Deserialize)] -struct ListPartsResponse { - parts: std::collections::BTreeMap, - is_truncated: bool, - next_part_number_marker: Option, -} - -#[derive(Debug, Deserialize)] -struct CompleteSuccessResponse { - key: String, -} - -#[derive(Debug, Deserialize)] -struct CompleteErrorDetail { - code: String, - message: String, -} - -#[derive(Debug, Deserialize)] -struct CompleteErrorResponse { - error: CompleteErrorDetail, +fn pn(n: u32) -> PartNumber { + PartNumber::new(n).unwrap() } async fn test_server() -> TestServer { @@ -192,10 +157,10 @@ async fn test_multipart_full_flow() -> Result<()> { assert_eq!(response.status(), reqwest::StatusCode::OK); let list: ListPartsResponse = response.json().await?; assert_eq!(list.parts.len(), 2); - assert!(list.parts.contains_key(&1)); - assert!(list.parts.contains_key(&2)); - assert_eq!(list.parts[&1].size, part1_data.len() as u64); - assert_eq!(list.parts[&2].size, part2_data.len() as u64); + assert_eq!(list.parts[0].part_number, pn(1)); + assert_eq!(list.parts[1].part_number, pn(2)); + assert_eq!(list.parts[0].size, part1_data.len() as u64); + assert_eq!(list.parts[1].size, part2_data.len() as u64); assert!(!list.is_truncated); // 5. Complete @@ -433,8 +398,9 @@ async fn test_upload_part_overwrite() -> Result<()> { assert_eq!(response.status(), reqwest::StatusCode::OK); let list: ListPartsResponse = response.json().await?; assert_eq!(list.parts.len(), 1); - assert_eq!(list.parts[&1].etag, second_etag.etag); - assert_eq!(list.parts[&1].size, 6); + assert_eq!(list.parts[0].part_number, pn(1)); + assert_eq!(list.parts[0].etag, second_etag.etag); + assert_eq!(list.parts[0].size, 6); // 5. Complete with the overwritten part complete_and_assert( diff --git a/objectstore-service/src/backend/gcs.rs b/objectstore-service/src/backend/gcs.rs index 80174148..43cc45c1 100644 --- a/objectstore-service/src/backend/gcs.rs +++ b/objectstore-service/src/backend/gcs.rs @@ -733,7 +733,7 @@ impl TryFrom for InitiateMultipartResponse { type Error = crate::error::Error; fn try_from(r: XmlInitiateMultipartUploadResponse) -> crate::error::Result { - UploadId::new(r.upload_id) + Ok(UploadId::new(r.upload_id)?) } } diff --git a/objectstore-service/src/backend/tiered.rs b/objectstore-service/src/backend/tiered.rs index 7d63dd54..1385e01d 100644 --- a/objectstore-service/src/backend/tiered.rs +++ b/objectstore-service/src/backend/tiered.rs @@ -580,7 +580,9 @@ impl TryInto for TieredUploadId { fn try_into(self) -> Result { let json = serde_json::to_vec(&self).map_err(|e| Error::serde("encoding multipart token", e))?; - UploadId::new(base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(json)) + Ok(UploadId::new( + base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(json), + )?) } } diff --git a/objectstore-service/src/error.rs b/objectstore-service/src/error.rs index 27344abe..46e38e28 100644 --- a/objectstore-service/src/error.rs +++ b/objectstore-service/src/error.rs @@ -91,6 +91,10 @@ pub enum Error { /// The functionality is not implemented by this instance of the service. #[error("not implemented")] NotImplemented, + + /// Invalid upload ID (e.g. path traversal attempt). + #[error(transparent)] + InvalidUploadId(#[from] objectstore_types::multipart::InvalidUploadId), } impl Error { @@ -147,6 +151,7 @@ impl Error { Self::Dropped => Level::ERROR, Self::UnexpectedTombstone => Level::ERROR, Self::NotImplemented => Level::ERROR, + Self::InvalidUploadId(_) => Level::DEBUG, Self::Generic { .. } => Level::ERROR, } } diff --git a/objectstore-service/src/multipart.rs b/objectstore-service/src/multipart.rs index 1cbe465a..b06e4814 100644 --- a/objectstore-service/src/multipart.rs +++ b/objectstore-service/src/multipart.rs @@ -1,68 +1,8 @@ //! Shared types for Objectstore's multipart upload protocol. -use std::fmt; -use std::ops::Deref; -use std::path::{Component, Path}; use std::time::SystemTime; -use serde::{Deserialize, Deserializer, Serialize}; - -use crate::error::Error; - -/// Identifier for an in-progress multipart upload. -/// -/// Validated on construction: non-empty and free of path-traversal components -/// (`..`, leading `/`, etc.), so it is always safe to use as a single path segment. -#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize)] -#[serde(transparent)] -pub struct UploadId(String); - -impl UploadId { - /// Returns the upload ID as a string slice. - pub fn as_str(&self) -> &str { - &self.0 - } - - /// Creates a new `UploadId` after validating the input. - pub fn new(s: String) -> Result { - if s.is_empty() { - return Err(Error::generic("upload_id must not be empty")); - } - for component in Path::new(&s).components() { - if !matches!(component, Component::Normal(_)) { - return Err(Error::generic(format!("invalid upload_id: {s}"))); - } - } - Ok(Self(s)) - } -} - -impl Deref for UploadId { - type Target = str; - fn deref(&self) -> &str { - &self.0 - } -} - -impl fmt::Display for UploadId { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.0.fmt(f) - } -} - -impl<'de> Deserialize<'de> for UploadId { - fn deserialize(deserializer: D) -> std::result::Result - where - D: Deserializer<'de>, - { - let s = String::deserialize(deserializer)?; - Self::new(s).map_err(serde::de::Error::custom) - } -} -/// 1-indexed position of a part within its multipart upload. -pub type PartNumber = std::num::NonZeroU32; -/// Opaque per-part identifier returned by the backend after a successful part upload. -pub type ETag = String; +pub use objectstore_types::multipart::{ETag, InvalidUploadId, PartNumber, UploadId}; /// Description of one part in the response to /// [`MultipartUploadBackend::list_parts`](crate::backend::common::MultipartUploadBackend::list_parts). diff --git a/objectstore-types/Cargo.toml b/objectstore-types/Cargo.toml index c29e094e..bed43716 100644 --- a/objectstore-types/Cargo.toml +++ b/objectstore-types/Cargo.toml @@ -12,6 +12,7 @@ publish = true [dependencies] http = { workspace = true } humantime = { workspace = true } +humantime-serde = { workspace = true } mediatype = "0.21.0" serde = { workspace = true } thiserror = { workspace = true } diff --git a/objectstore-types/src/lib.rs b/objectstore-types/src/lib.rs index 2b7f9e28..13d2946b 100644 --- a/objectstore-types/src/lib.rs +++ b/objectstore-types/src/lib.rs @@ -32,4 +32,5 @@ pub mod auth; pub mod metadata; +pub mod multipart; pub mod scope; diff --git a/objectstore-types/src/multipart.rs b/objectstore-types/src/multipart.rs new file mode 100644 index 00000000..b56cf44e --- /dev/null +++ b/objectstore-types/src/multipart.rs @@ -0,0 +1,161 @@ +//! Types for the multipart upload protocol. + +use std::fmt; +use std::num::NonZeroU32; +use std::ops::Deref; +use std::path::{Component, Path}; +use std::time::SystemTime; + +use serde::{Deserialize, Deserializer, Serialize}; + +/// 1-indexed position of a part within its multipart upload. +pub type PartNumber = NonZeroU32; + +/// Opaque entity tag identifying a specific version of an uploaded part. +pub type ETag = String; + +/// Identifier for an in-progress multipart upload. +/// +/// Validated on construction: non-empty and free of path-traversal components +/// (`..`, leading `/`, etc.), so it is always safe to use as a single path segment. +#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize)] +#[serde(transparent)] +pub struct UploadId(String); + +/// Error returned when an [`UploadId`] fails validation. +#[derive(Debug, thiserror::Error)] +#[error("invalid upload_id: {0}")] +pub struct InvalidUploadId(String); + +impl UploadId { + /// Creates a new `UploadId` after validating the input. + pub fn new(s: String) -> Result { + if s.is_empty() { + return Err(InvalidUploadId("must not be empty".into())); + } + for component in Path::new(&s).components() { + if !matches!(component, Component::Normal(_)) { + return Err(InvalidUploadId(s)); + } + } + Ok(Self(s)) + } + + /// Returns the upload ID as a string slice. + pub fn as_str(&self) -> &str { + &self.0 + } +} + +impl Deref for UploadId { + type Target = str; + fn deref(&self) -> &str { + &self.0 + } +} + +impl fmt::Display for UploadId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +impl<'de> Deserialize<'de> for UploadId { + fn deserialize(deserializer: D) -> std::result::Result + where + D: Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + Self::new(s).map_err(serde::de::Error::custom) + } +} + +/// Response from initiating a multipart upload. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct InitiateResponse { + /// The object key (server-generated or user-provided). + pub key: String, + /// The upload session identifier for subsequent requests. + pub upload_id: UploadId, +} + +/// Response from uploading a single part. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct UploadPartResponse { + /// Opaque identifier of the uploaded part. + pub etag: ETag, +} + +/// Information about a single uploaded part, as returned by list-parts. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PartInfo { + /// The part number. + pub part_number: PartNumber, + /// Opaque identifier of the part. + pub etag: ETag, + /// When the part was last modified. + #[serde(with = "humantime_serde")] + pub last_modified: SystemTime, + /// Size of the part in bytes. + pub size: u64, +} + +impl From for CompletePart { + fn from(info: PartInfo) -> Self { + Self { + part_number: info.part_number, + etag: info.etag, + } + } +} + +/// Response from listing parts of a multipart upload. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ListPartsResponse { + /// Parts uploaded so far. + pub parts: Vec, + /// Whether the response was truncated. + pub is_truncated: bool, + /// Marker for the next page of results, if truncated. + #[serde(skip_serializing_if = "Option::is_none")] + pub next_part_number_marker: Option, +} + +/// A single part reference used in the complete request. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CompletePart { + /// The part number. + pub part_number: PartNumber, + /// The etag returned when this part was uploaded. + pub etag: ETag, +} + +/// Request body for completing a multipart upload. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CompleteRequest { + /// Ordered list of all parts that make up the object. + pub parts: Vec, +} + +/// Successful response from completing a multipart upload. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CompleteSuccessResponse { + /// The object key. + pub key: String, +} + +/// Detail of an error that occurred during multipart completion. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CompleteErrorDetail { + /// Error code. + pub code: String, + /// Human-readable error message. + pub message: String, +} + +/// Error response from completing a multipart upload. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CompleteErrorResponse { + /// The error detail. + pub error: CompleteErrorDetail, +}