From d77f2c28a7cedb585746b762804097fa3da75deb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Lemaire-Giroud?= Date: Wed, 21 May 2025 09:28:32 +0200 Subject: [PATCH 1/6] fix: prevent leaking secrets --- src/client/mod.rs | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/client/mod.rs b/src/client/mod.rs index d5df651..dc5069e 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -135,7 +135,7 @@ pub trait RestClient: Debug { // ----------------------------------------------------------------------------- // ClientCredentials structure -#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug)] +#[derive(Serialize, Deserialize, PartialEq, Eq, Clone)] #[serde(untagged)] pub enum Credentials { OAuth1 { @@ -160,6 +160,17 @@ pub enum Credentials { }, } +impl fmt::Debug for Credentials { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + // NOTE: ensure secrets are not leaked in logs + match self { + Self::OAuth1 { .. } => f.write_str("OAuth1"), + Self::Basic { .. } => f.write_str("Basic"), + Self::Bearer { .. } => f.write_str("Bearer"), + } + } +} + impl Default for Credentials { fn default() -> Self { Self::OAuth1 { From 110a220948311c1ad1fbd0dc3b047c082e35832b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Lemaire-Giroud?= Date: Wed, 21 May 2025 09:32:04 +0200 Subject: [PATCH 2/6] chore: add SseErrorOf type alias --- src/client/sse.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/client/sse.rs b/src/client/sse.rs index 0174139..91f6816 100644 --- a/src/client/sse.rs +++ b/src/client/sse.rs @@ -30,13 +30,13 @@ use url::Url; use super::Request; -pub type SseResult = - Result, SseError<::Error, ::Err, ::Err>>; +pub type SseErrorOf = + SseError<::Error, ::Err, ::Err>; -pub type SseBuildResult = Result< - SseStream, - SseError<::Error, ::Err, ::Err>, ->; +pub type SseResult = Result, SseErrorOf>; + +pub type SseBuildResult = + Result, SseErrorOf>; /// Default initial capacity of the buffer of the [`SseStream`]. pub const DEFAULT_INITIAL_CAPACITY: usize = 512; From 561c375857089d8e80333b79b94fc9db2ecc8af7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Lemaire-Giroud?= Date: Wed, 21 May 2025 09:37:44 +0200 Subject: [PATCH 3/6] feat: let endpoint argument be any type that impl reqwest::IntoUrl (closes #53) * move request method to `RestClient` trait for better separation of concerns * move execute method to `Execute` trait for better type inference --- src/client/mod.rs | 176 +++++++++++++++++++++------------------------- src/client/sse.rs | 16 ++--- 2 files changed, 89 insertions(+), 103 deletions(-) diff --git a/src/client/mod.rs b/src/client/mod.rs index dc5069e..5681921 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -19,11 +19,11 @@ use bytes::Buf; use crypto_common::InvalidLength; use hmac::{Hmac, Mac}; #[cfg(feature = "logging")] -use log::{Level, error, log_enabled, trace}; +use log::{error, trace}; #[cfg(feature = "metrics")] use prometheus::{CounterVec, opts, register_counter_vec}; use reqwest::{ - Method, StatusCode, + IntoUrl, Method, StatusCode, header::{self, HeaderValue}, }; use serde::{Deserialize, Serialize, de::DeserializeOwned}; @@ -71,21 +71,12 @@ static CLIENT_REQUEST_DURATION: LazyLock = LazyLock::new(|| { type HmacSha512 = Hmac; // ----------------------------------------------------------------------------- -// Request trait +// Execute trait -pub trait Request { +/// Execute HTTP requests. +pub trait Execute { type Error; - fn request( - &self, - method: &Method, - endpoint: &str, - payload: &T, - ) -> impl Future> + Send - where - T: ?Sized + Serialize + Debug + Send + Sync, - U: DeserializeOwned + Debug + Send + Sync; - fn execute( &self, request: reqwest::Request, @@ -95,16 +86,24 @@ pub trait Request { // ----------------------------------------------------------------------------- // RestClient trait -pub trait RestClient: Debug { - type Error; +pub trait RestClient: Execute { + fn request( + &self, + method: &Method, + endpoint: X, + payload: &T, + ) -> impl Future> + Send + where + T: ?Sized + Serialize + Debug + Send + Sync, + U: DeserializeOwned + Debug + Send + Sync; - fn get(&self, endpoint: &str) -> impl Future> + Send + fn get(&self, endpoint: X) -> impl Future> + Send where T: DeserializeOwned + Debug + Send + Sync; fn post( &self, - endpoint: &str, + endpoint: X, payload: &T, ) -> impl Future> + Send where @@ -113,7 +112,7 @@ pub trait RestClient: Debug { fn put( &self, - endpoint: &str, + endpoint: X, payload: &T, ) -> impl Future> + Send where @@ -122,14 +121,14 @@ pub trait RestClient: Debug { fn patch( &self, - endpoint: &str, + endpoint: X, payload: &T, ) -> impl Future> + Send where T: ?Sized + Serialize + Debug + Send + Sync, U: DeserializeOwned + Debug + Send + Sync; - fn delete(&self, endpoint: &str) -> impl Future> + Send; + fn delete(&self, endpoint: X) -> impl Future> + Send; } // ----------------------------------------------------------------------------- @@ -440,8 +439,6 @@ pub enum ClientError { Digest(SignerError), #[error("failed to serialize signature as header value, {0}")] SerializeHeaderValue(header::InvalidHeaderValue), - #[error("failed to parse url endpoint, {0}")] - ParseUrlEndpoint(url::ParseError), } // ----------------------------------------------------------------------------- @@ -457,69 +454,15 @@ pub struct Client { credentials: Option, } -impl Request for Client { +impl Execute for Client { type Error = ClientError; - #[cfg_attr(feature = "tracing", tracing::instrument)] - async fn request( - &self, - method: &Method, - endpoint: &str, - payload: &T, - ) -> Result - where - T: ?Sized + Serialize + Debug + Send + Sync, - U: DeserializeOwned + Debug + Send + Sync, - { - let buf = serde_json::to_vec(payload).map_err(ClientError::Serialize)?; - let mut request = reqwest::Request::new( - method.to_owned(), - endpoint.parse().map_err(ClientError::ParseUrlEndpoint)?, - ); - - request - .headers_mut() - .insert(header::CONTENT_TYPE, APPLICATION_JSON); - - request - .headers_mut() - .insert(header::CONTENT_LENGTH, HeaderValue::from(buf.len())); - - request.headers_mut().insert(header::ACCEPT_CHARSET, UTF8); - - request - .headers_mut() - .insert(header::ACCEPT, APPLICATION_JSON); - - *request.body_mut() = Some(buf.into()); - - let res = self.execute(request).await?; - let status = res.status(); - let buf = res.bytes().await.map_err(ClientError::BodyAggregation)?; - - #[cfg(feature = "logging")] - if log_enabled!(Level::Trace) { - trace!( - "received response, endpoint: '{endpoint}', method: '{method}', status: '{}'", - status.as_u16() - ); - } - - if !status.is_success() { - return Err(ClientError::StatusCode( - status, - serde_json::from_reader(buf.reader()).map_err(ClientError::Deserialize)?, - )); - } - - serde_json::from_reader(buf.reader()).map_err(ClientError::Deserialize) - } - + /// Executes the given HTTP `request`. #[cfg_attr(feature = "tracing", tracing::instrument(skip(self)))] fn execute( &self, mut request: reqwest::Request, - ) -> impl Future> + 'static { + ) -> impl Future> + Send + 'static { let client = self.clone(); async move { @@ -593,18 +536,63 @@ impl Request for Client { } } -impl RestClient for Client { - type Error = ClientError; +impl RestClient for Client { + #[cfg_attr(feature = "tracing", tracing::instrument)] + async fn request( + &self, + method: &Method, + endpoint: X, + payload: &T, + ) -> Result + where + T: ?Sized + Serialize + Debug + Send + Sync, + U: DeserializeOwned + Debug + Send + Sync, + { + let buf = serde_json::to_vec(payload).map_err(ClientError::Serialize)?; + + let url = endpoint.into_url().map_err(ClientError::Request)?; + + #[cfg(feature = "logging")] + let endpoint = url.as_str().to_owned(); + + let mut request = reqwest::Request::new(method.to_owned(), url); + + let headers = request.headers_mut(); + headers.insert(header::CONTENT_TYPE, APPLICATION_JSON); + headers.insert(header::CONTENT_LENGTH, HeaderValue::from(buf.len())); + headers.insert(header::ACCEPT_CHARSET, UTF8); + headers.insert(header::ACCEPT, APPLICATION_JSON); + + *request.body_mut() = Some(buf.into()); + + let res = self.execute(request).await?; + let status = res.status(); + let buf = res.bytes().await.map_err(ClientError::BodyAggregation)?; + + #[cfg(feature = "logging")] + trace!( + "received response, endpoint: '{endpoint}', method: '{method}', status: '{}'", + status.as_u16() + ); + + if !status.is_success() { + return Err(ClientError::StatusCode( + status, + serde_json::from_reader(buf.reader()).map_err(ClientError::Deserialize)?, + )); + } + + serde_json::from_reader(buf.reader()).map_err(ClientError::Deserialize) + } #[cfg_attr(feature = "tracing", tracing::instrument)] - async fn get(&self, endpoint: &str) -> Result + async fn get(&self, endpoint: X) -> Result where T: DeserializeOwned + Debug + Send + Sync, { - let mut req = reqwest::Request::new( - Method::GET, - endpoint.parse().map_err(ClientError::ParseUrlEndpoint)?, - ); + let url = endpoint.into_url().map_err(ClientError::Request)?; + + let mut req = reqwest::Request::new(Method::GET, url); req.headers_mut().insert(header::ACCEPT_CHARSET, UTF8); @@ -625,7 +613,7 @@ impl RestClient for Client { } #[cfg_attr(feature = "tracing", tracing::instrument)] - async fn post(&self, endpoint: &str, payload: &T) -> Result + async fn post(&self, endpoint: X, payload: &T) -> Result where T: ?Sized + Serialize + Debug + Send + Sync, U: DeserializeOwned + Debug + Send + Sync, @@ -634,7 +622,7 @@ impl RestClient for Client { } #[cfg_attr(feature = "tracing", tracing::instrument)] - async fn put(&self, endpoint: &str, payload: &T) -> Result + async fn put(&self, endpoint: X, payload: &T) -> Result where T: ?Sized + Serialize + Debug + Send + Sync, U: DeserializeOwned + Debug + Send + Sync, @@ -643,7 +631,7 @@ impl RestClient for Client { } #[cfg_attr(feature = "tracing", tracing::instrument)] - async fn patch(&self, endpoint: &str, payload: &T) -> Result + async fn patch(&self, endpoint: X, payload: &T) -> Result where T: ?Sized + Serialize + Debug + Send + Sync, U: DeserializeOwned + Debug + Send + Sync, @@ -652,11 +640,9 @@ impl RestClient for Client { } #[cfg_attr(feature = "tracing", tracing::instrument)] - async fn delete(&self, endpoint: &str) -> Result<(), Self::Error> { - let req = reqwest::Request::new( - Method::DELETE, - endpoint.parse().map_err(ClientError::ParseUrlEndpoint)?, - ); + async fn delete(&self, endpoint: X) -> Result<(), Self::Error> { + let url = endpoint.into_url().map_err(ClientError::Request)?; + let req = reqwest::Request::new(Method::DELETE, url); let res = self.execute(req).await?; let status = res.status(); diff --git a/src/client/sse.rs b/src/client/sse.rs index 91f6816..39b4726 100644 --- a/src/client/sse.rs +++ b/src/client/sse.rs @@ -28,7 +28,7 @@ use reqwest::{ use tracing::trace; use url::Url; -use super::Request; +use super::Execute; pub type SseErrorOf = SseError<::Error, ::Err, ::Err>; @@ -504,7 +504,7 @@ impl fmt::Debug for SseState { /// Stream of Server-Sent [`Event`]s. #[derive(Debug)] -pub struct SseStream { +pub struct SseStream { state: SseState, parser: EventParser, max_retry: Option<(u64, u64)>, @@ -513,7 +513,7 @@ pub struct SseStream { client: C, } -impl SseStream { +impl SseStream { pub fn builder(client: C, endpoint: U) -> SseStreamBuilder { SseStreamBuilder::new(client, endpoint.into_url()) } @@ -537,8 +537,8 @@ impl SseStream { } } -impl Stream for SseStream { - type Item = Result, SseError>; +impl Stream for SseStream { + type Item = SseResult; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = &mut *self; @@ -660,7 +660,7 @@ impl Stream for SseStream { } } -impl FusedStream for SseStream { +impl FusedStream for SseStream { fn is_terminated(&self) -> bool { matches!(self.state, SseState::Closed) } @@ -766,7 +766,7 @@ impl SseStreamBuilder { #[cfg_attr(feature = "tracing", tracing::instrument)] pub fn stream(self) -> SseBuildResult where - C: Request + fmt::Debug, + C: Execute + fmt::Debug, K: FromStr, V: FromStr, { @@ -856,7 +856,7 @@ pub trait SseClient { } } -impl SseClient for T { +impl SseClient for T { fn sse(&self, endpoint: U) -> SseStreamBuilder where K: FromStr + fmt::Debug + Send + 'static, From 25dea051ab692825bb2cf5fab6090ce92caadfa2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Lemaire-Giroud?= Date: Wed, 21 May 2025 10:39:40 +0200 Subject: [PATCH 4/6] chore: normalize imports and derives --- src/client/mod.rs | 67 +++++++++++++++++++++++------------------------ src/client/sse.rs | 6 ++--- 2 files changed, 36 insertions(+), 37 deletions(-) diff --git a/src/client/mod.rs b/src/client/mod.rs index 5681921..de053be 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -4,11 +4,10 @@ //! to interact with the Clever-Cloud's api, but has been extended to be more //! generic. +use core::{error::Error, fmt, future::Future}; + use std::{ collections::BTreeMap, - error::Error, - fmt::{self, Debug, Display, Formatter}, - future::Future, time::{SystemTime, SystemTimeError}, }; #[cfg(feature = "metrics")] @@ -94,12 +93,12 @@ pub trait RestClient: Execute { payload: &T, ) -> impl Future> + Send where - T: ?Sized + Serialize + Debug + Send + Sync, - U: DeserializeOwned + Debug + Send + Sync; + T: ?Sized + Serialize + fmt::Debug + Send + Sync, + U: DeserializeOwned + fmt::Debug + Send + Sync; - fn get(&self, endpoint: X) -> impl Future> + Send + fn get(&self, endpoint: X) -> impl Future> + Send where - T: DeserializeOwned + Debug + Send + Sync; + U: DeserializeOwned + fmt::Debug + Send + Sync; fn post( &self, @@ -107,8 +106,8 @@ pub trait RestClient: Execute { payload: &T, ) -> impl Future> + Send where - T: ?Sized + Serialize + Debug + Send + Sync, - U: DeserializeOwned + Debug + Send + Sync; + T: ?Sized + Serialize + fmt::Debug + Send + Sync, + U: DeserializeOwned + fmt::Debug + Send + Sync; fn put( &self, @@ -116,8 +115,8 @@ pub trait RestClient: Execute { payload: &T, ) -> impl Future> + Send where - T: ?Sized + Serialize + Debug + Send + Sync, - U: DeserializeOwned + Debug + Send + Sync; + T: ?Sized + Serialize + fmt::Debug + Send + Sync, + U: DeserializeOwned + fmt::Debug + Send + Sync; fn patch( &self, @@ -125,16 +124,16 @@ pub trait RestClient: Execute { payload: &T, ) -> impl Future> + Send where - T: ?Sized + Serialize + Debug + Send + Sync, - U: DeserializeOwned + Debug + Send + Sync; + T: ?Sized + Serialize + fmt::Debug + Send + Sync, + U: DeserializeOwned + fmt::Debug + Send + Sync; fn delete(&self, endpoint: X) -> impl Future> + Send; } // ----------------------------------------------------------------------------- -// ClientCredentials structure +// Credentials structure -#[derive(Serialize, Deserialize, PartialEq, Eq, Clone)] +#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(untagged)] pub enum Credentials { OAuth1 { @@ -160,7 +159,7 @@ pub enum Credentials { } impl fmt::Debug for Credentials { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { // NOTE: ensure secrets are not leaked in logs match self { Self::OAuth1 { .. } => f.write_str("OAuth1"), @@ -221,7 +220,7 @@ pub const OAUTH1_VERSION: &str = "oauth_version"; pub const OAUTH1_VERSION_1: &str = "1.0"; pub const OAUTH1_TOKEN: &str = "oauth_token"; -pub trait OAuth1: Debug { +pub trait OAuth1: fmt::Debug { type Error; // `params` returns OAuth1 parameters without the signature one @@ -258,7 +257,7 @@ pub trait OAuth1: Debug { // ----------------------------------------------------------------------------- // ResponseError structure -#[derive(Serialize, Deserialize, Clone, Debug)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct ResponseError { #[serde(rename = "id")] pub id: u32, @@ -268,8 +267,8 @@ pub struct ResponseError { pub kind: String, } -impl Display for ResponseError { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { +impl fmt::Display for ResponseError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, "got response {} ({}), {}", @@ -283,7 +282,7 @@ impl Error for ResponseError {} // ----------------------------------------------------------------------------- // SignerError enum -#[derive(thiserror::Error, Debug)] +#[derive(Debug, thiserror::Error)] pub enum SignerError { #[error("failed to compute invalid key length, {0}")] Digest(InvalidLength), @@ -300,7 +299,7 @@ pub enum SignerError { // ----------------------------------------------------------------------------- // Signer structure -#[derive(PartialEq, Eq, Clone, Debug)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct Signer { pub nonce: String, pub timestamp: u64, @@ -421,7 +420,7 @@ impl TryFrom for Signer { // ----------------------------------------------------------------------------- // ClientError enum -#[derive(thiserror::Error, Debug)] +#[derive(Debug, thiserror::Error)] pub enum ClientError { #[error("failed to execute request, {0}")] Request(reqwest::Error), @@ -448,7 +447,7 @@ pub const APPLICATION_JSON: HeaderValue = HeaderValue::from_static("application/ pub const UTF8: HeaderValue = HeaderValue::from_static("utf-8"); -#[derive(Clone, Debug)] +#[derive(Debug, Clone)] pub struct Client { inner: reqwest::Client, credentials: Option, @@ -545,8 +544,8 @@ impl RestClient for Client { payload: &T, ) -> Result where - T: ?Sized + Serialize + Debug + Send + Sync, - U: DeserializeOwned + Debug + Send + Sync, + T: ?Sized + Serialize + fmt::Debug + Send + Sync, + U: DeserializeOwned + fmt::Debug + Send + Sync, { let buf = serde_json::to_vec(payload).map_err(ClientError::Serialize)?; @@ -586,9 +585,9 @@ impl RestClient for Client { } #[cfg_attr(feature = "tracing", tracing::instrument)] - async fn get(&self, endpoint: X) -> Result + async fn get(&self, endpoint: X) -> Result where - T: DeserializeOwned + Debug + Send + Sync, + U: DeserializeOwned + fmt::Debug + Send + Sync, { let url = endpoint.into_url().map_err(ClientError::Request)?; @@ -615,8 +614,8 @@ impl RestClient for Client { #[cfg_attr(feature = "tracing", tracing::instrument)] async fn post(&self, endpoint: X, payload: &T) -> Result where - T: ?Sized + Serialize + Debug + Send + Sync, - U: DeserializeOwned + Debug + Send + Sync, + T: ?Sized + Serialize + fmt::Debug + Send + Sync, + U: DeserializeOwned + fmt::Debug + Send + Sync, { self.request(&Method::POST, endpoint, payload).await } @@ -624,8 +623,8 @@ impl RestClient for Client { #[cfg_attr(feature = "tracing", tracing::instrument)] async fn put(&self, endpoint: X, payload: &T) -> Result where - T: ?Sized + Serialize + Debug + Send + Sync, - U: DeserializeOwned + Debug + Send + Sync, + T: ?Sized + Serialize + fmt::Debug + Send + Sync, + U: DeserializeOwned + fmt::Debug + Send + Sync, { self.request(&Method::PUT, endpoint, payload).await } @@ -633,8 +632,8 @@ impl RestClient for Client { #[cfg_attr(feature = "tracing", tracing::instrument)] async fn patch(&self, endpoint: X, payload: &T) -> Result where - T: ?Sized + Serialize + Debug + Send + Sync, - U: DeserializeOwned + Debug + Send + Sync, + T: ?Sized + Serialize + fmt::Debug + Send + Sync, + U: DeserializeOwned + fmt::Debug + Send + Sync, { self.request(&Method::PATCH, endpoint, payload).await } diff --git a/src/client/sse.rs b/src/client/sse.rs index 39b4726..aeb401f 100644 --- a/src/client/sse.rs +++ b/src/client/sse.rs @@ -111,7 +111,7 @@ pub struct Event { // EVENT PARSER //////////////////////////////////////////////////////////////// -#[derive(Copy, Clone, Debug, PartialEq)] +#[derive(Debug, Clone, Copy, PartialEq)] enum Eol { // carriage return (`\r`) or line feed (`\n`) CrOrLf = 1, @@ -378,7 +378,7 @@ impl fmt::Display for EventId { /// /// Specialized [`fmt::Display`] implementation that escapes newlines in serial /// JSON representation for Server-Sent Events (SSE) streaming compatibility. -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct Json(pub T); impl PartialEq for Json { @@ -682,7 +682,7 @@ pub struct SseStreamBuilder { } impl fmt::Debug for SseStreamBuilder { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("SseStreamBuilder") .field("client", &self.client) .field("endpoint", &self.endpoint) From 6492e86345d0443388cd7f8a3b59b9323234ac75 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Lemaire-Giroud?= Date: Wed, 21 May 2025 10:49:09 +0200 Subject: [PATCH 5/6] chore: simplify OAuth1 signature --- src/client/mod.rs | 28 +++++++++++----------------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/src/client/mod.rs b/src/client/mod.rs index de053be..2b78be5 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -333,30 +333,24 @@ impl OAuth1 for Signer { #[cfg_attr(feature = "tracing", tracing::instrument)] fn signature(&self, method: &str, endpoint: &str) -> Result { - let (host, query) = match endpoint.find('?') { - None => (endpoint, ""), - // split one character further to not get the '?' character - Some(position) => endpoint.split_at(position), - }; - - let query = query.strip_prefix('?').unwrap_or(query); let mut params = self.params(); - if !query.is_empty() { - for qparam in query.split('&') { - let (k, v) = qparam.split_at(qparam.find('=').ok_or_else(|| { - SignerError::Parse(format!("failed to parse query parameter, {qparam}")) - })?); - - if !params.contains_key(k) { - params.insert(k.to_string(), v.strip_prefix('=').unwrap_or(v).to_owned()); + let host = match endpoint.split_once('?') { + None => endpoint, + Some((host, query)) => { + for qparam in query.split('&') { + let (k, v) = qparam.split_once('=').ok_or_else(|| { + SignerError::Parse(format!("failed to parse query parameter, {qparam}")) + })?; + params.entry(k.to_owned()).or_insert(v.to_owned()); } + host } - } + }; let mut params = params .iter() - .map(|(k, v)| format!("{}={}", k, urlencoding::encode(v))) + .map(|(k, v)| format!("{k}={}", urlencoding::encode(v))) .collect::>(); params.sort(); From c1fad59f546753ba7d040a35420587eb090da5e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Lemaire-Giroud?= Date: Wed, 21 May 2025 19:30:48 +0200 Subject: [PATCH 6/6] chore: catch panics on too many headers and prevent unnecessary allocations --- src/client/mod.rs | 198 ++++++++++++++++++++++++++-------------------- src/client/sse.rs | 23 +++--- 2 files changed, 126 insertions(+), 95 deletions(-) diff --git a/src/client/mod.rs b/src/client/mod.rs index 2b78be5..e4bcdd0 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -4,9 +4,10 @@ //! to interact with the Clever-Cloud's api, but has been extended to be more //! generic. -use core::{error::Error, fmt, future::Future}; +use core::{error::Error, fmt, future::Future, time::Duration}; use std::{ + borrow::Cow, collections::BTreeMap, time::{SystemTime, SystemTimeError}, }; @@ -224,7 +225,7 @@ pub trait OAuth1: fmt::Debug { type Error; // `params` returns OAuth1 parameters without the signature one - fn params(&self) -> BTreeMap; + fn params(&self) -> BTreeMap, Cow<'_, str>>; // `signature` returns the computed signature from given parameters fn signature(&self, method: &str, endpoint: &str) -> Result; @@ -238,9 +239,9 @@ pub trait OAuth1: fmt::Debug { let signature = self.signature(method, endpoint)?; let mut params = self.params(); - params.insert( - OAUTH1_SIGNATURE.to_string(), - urlencoding::encode(&signature).into_owned(), + let _ = params.insert( + Cow::Borrowed(OAUTH1_SIGNATURE), + urlencoding::encode(&signature), ); let mut base = params @@ -299,35 +300,73 @@ pub enum SignerError { // ----------------------------------------------------------------------------- // Signer structure -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct Signer { - pub nonce: String, - pub timestamp: u64, - pub token: String, - pub secret: String, - pub consumer_key: String, - pub consumer_secret: String, +#[derive(Clone, PartialEq, Eq)] +pub struct Signer { + pub nonce: Uuid, + pub timestamp: Duration, + pub token: T, + pub secret: T, + pub consumer_key: T, + pub consumer_secret: T, } -impl OAuth1 for Signer { +impl fmt::Debug for Signer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Signer") + .field("nonce", &self.nonce) + .field("timestamp", &self.timestamp) + .finish_non_exhaustive() + } +} + +impl Signer { + #[cfg_attr(feature = "tracing", tracing::instrument)] + fn new(token: T, secret: T, consumer_key: T, consumer_secret: T) -> Result { + let nonce = Uuid::new_v4(); + let timestamp = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .map_err(SignerError::UnixEpochTime)?; + Ok(Self { + nonce, + timestamp, + token, + secret, + consumer_key, + consumer_secret, + }) + } +} + +impl + fmt::Debug> OAuth1 for Signer { type Error = SignerError; #[cfg_attr(feature = "tracing", tracing::instrument)] - fn params(&self) -> BTreeMap { + fn params(&self) -> BTreeMap, Cow<'_, str>> { let mut params = BTreeMap::new(); - - params.insert( - OAUTH1_CONSUMER_KEY.to_string(), - self.consumer_key.to_string(), + let _ = params.insert( + Cow::Borrowed(OAUTH1_CONSUMER_KEY), + self.consumer_key.as_ref().into(), + ); + let _ = params.insert( + Cow::Borrowed(OAUTH1_NONCE), + Cow::Owned(self.nonce.to_string()), + ); + let _ = params.insert( + Cow::Borrowed(OAUTH1_SIGNATURE_METHOD), + Cow::Borrowed(OAUTH1_SIGNATURE_HMAC_SHA512), ); - params.insert(OAUTH1_NONCE.to_string(), self.nonce.to_string()); - params.insert( - OAUTH1_SIGNATURE_METHOD.to_string(), - OAUTH1_SIGNATURE_HMAC_SHA512.to_string(), + let _ = params.insert( + Cow::Borrowed(OAUTH1_TIMESTAMP), + Cow::Owned(self.timestamp.as_secs().to_string()), + ); + let _ = params.insert( + Cow::Borrowed(OAUTH1_VERSION), + Cow::Borrowed(OAUTH1_VERSION_1), + ); + let _ = params.insert( + Cow::Borrowed(OAUTH1_TOKEN), + Cow::Borrowed(self.token.as_ref()), ); - params.insert(OAUTH1_TIMESTAMP.to_string(), self.timestamp.to_string()); - params.insert(OAUTH1_VERSION.to_string(), OAUTH1_VERSION_1.to_string()); - params.insert(OAUTH1_TOKEN.to_string(), self.token.to_string()); params } @@ -335,6 +374,8 @@ impl OAuth1 for Signer { fn signature(&self, method: &str, endpoint: &str) -> Result { let mut params = self.params(); + // TODO: we could use query_pairs on Url + let host = match endpoint.split_once('?') { None => endpoint, Some((host, query)) => { @@ -342,7 +383,7 @@ impl OAuth1 for Signer { let (k, v) = qparam.split_once('=').ok_or_else(|| { SignerError::Parse(format!("failed to parse query parameter, {qparam}")) })?; - params.entry(k.to_owned()).or_insert(v.to_owned()); + let _ = params.entry(Cow::Borrowed(k)).or_insert(Cow::Borrowed(v)); } host } @@ -375,8 +416,8 @@ impl OAuth1 for Signer { fn signing_key(&self) -> String { format!( "{}&{}", - urlencoding::encode(&self.consumer_secret), - urlencoding::encode(&self.secret) + urlencoding::encode(self.consumer_secret.as_ref()), + urlencoding::encode(self.secret.as_ref()) ) } } @@ -386,26 +427,13 @@ impl TryFrom for Signer { #[cfg_attr(feature = "tracing", tracing::instrument)] fn try_from(credentials: Credentials) -> Result { - let nonce = Uuid::new_v4().to_string(); - let timestamp = SystemTime::now() - .duration_since(SystemTime::UNIX_EPOCH) - .map_err(SignerError::UnixEpochTime)? - .as_secs(); - match credentials { Credentials::OAuth1 { token, secret, consumer_key, consumer_secret, - } => Ok(Self { - nonce, - timestamp, - token, - secret, - consumer_key, - consumer_secret, - }), + } => Self::new(token, secret, consumer_key, consumer_secret), _ => Err(SignerError::InvalidCredentials), } } @@ -432,6 +460,8 @@ pub enum ClientError { Digest(SignerError), #[error("failed to serialize signature as header value, {0}")] SerializeHeaderValue(header::InvalidHeaderValue), + #[error("failed to insert header in request: too many entries")] + TooManyHeaders(#[from] reqwest::header::MaxSizeReached), } // ----------------------------------------------------------------------------- @@ -462,36 +492,33 @@ impl Execute for Client { let method = request.method().to_string(); let endpoint = request.url().to_string(); - if !request.headers().contains_key(&header::AUTHORIZATION) { - match &client.credentials { - Some(Credentials::Bearer { token }) => { - request.headers_mut().insert( - header::AUTHORIZATION, + if let Some(credentials) = &client.credentials { + if let header::Entry::Vacant(vacant_entry) = + request.headers_mut().entry(header::AUTHORIZATION) + { + let header_value = match credentials { + Credentials::OAuth1 { + token, + secret, + consumer_key, + consumer_secret, + } => Signer::new(token, secret, consumer_key, consumer_secret) + .map_err(ClientError::Signer)? + .sign(&method, &endpoint) + .map_err(ClientError::Digest)? + .parse() + .map_err(ClientError::SerializeHeaderValue)?, + Credentials::Basic { username, password } => { + let token = BASE64_ENGINE.encode(format!("{username}:{password}")); + HeaderValue::from_str(&format!("Basic {token}")) + .map_err(ClientError::SerializeHeaderValue)? + } + Credentials::Bearer { token } => { HeaderValue::from_str(&format!("Bearer {token}")) - .map_err(ClientError::SerializeHeaderValue)?, - ); - } - Some(Credentials::Basic { username, password }) => { - let token = BASE64_ENGINE.encode(format!("{username}:{password}")); - - request.headers_mut().insert( - header::AUTHORIZATION, - HeaderValue::from_str(&format!("Basic {token}",)) - .map_err(ClientError::SerializeHeaderValue)?, - ); - } - Some(credentials) => { - request.headers_mut().insert( - header::AUTHORIZATION, - Signer::try_from(credentials.to_owned()) - .map_err(ClientError::Signer)? - .sign(&method, &endpoint) - .map_err(ClientError::Digest)? - .parse() - .map_err(ClientError::SerializeHeaderValue)?, - ); - } - _ => {} + .map_err(ClientError::SerializeHeaderValue)? + } + }; + let _ = vacant_entry.try_insert(header_value)?; } } @@ -551,10 +578,10 @@ impl RestClient for Client { let mut request = reqwest::Request::new(method.to_owned(), url); let headers = request.headers_mut(); - headers.insert(header::CONTENT_TYPE, APPLICATION_JSON); - headers.insert(header::CONTENT_LENGTH, HeaderValue::from(buf.len())); - headers.insert(header::ACCEPT_CHARSET, UTF8); - headers.insert(header::ACCEPT, APPLICATION_JSON); + let _ = headers.try_insert(header::CONTENT_TYPE, APPLICATION_JSON)?; + let _ = headers.try_insert(header::CONTENT_LENGTH, HeaderValue::from(buf.len()))?; + let _ = headers.try_insert(header::ACCEPT_CHARSET, UTF8)?; + let _ = headers.try_insert(header::ACCEPT, APPLICATION_JSON)?; *request.body_mut() = Some(buf.into()); @@ -585,15 +612,18 @@ impl RestClient for Client { { let url = endpoint.into_url().map_err(ClientError::Request)?; - let mut req = reqwest::Request::new(Method::GET, url); - - req.headers_mut().insert(header::ACCEPT_CHARSET, UTF8); - - req.headers_mut().insert(header::ACCEPT, APPLICATION_JSON); + let mut request = reqwest::Request::new(Method::GET, url); - let res = self.execute(req).await?; - let status = res.status(); - let buf = res.bytes().await.map_err(ClientError::BodyAggregation)?; + let headers = request.headers_mut(); + let _ = headers.try_insert(header::ACCEPT_CHARSET, UTF8)?; + let _ = headers.try_insert(header::ACCEPT, APPLICATION_JSON)?; + + let response = self.execute(request).await?; + let status = response.status(); + let buf = response + .bytes() + .await + .map_err(ClientError::BodyAggregation)?; if !status.is_success() { return Err(ClientError::StatusCode( diff --git a/src/client/sse.rs b/src/client/sse.rs index aeb401f..547140c 100644 --- a/src/client/sse.rs +++ b/src/client/sse.rs @@ -38,6 +38,8 @@ pub type SseResult = Result, SseErrorOf = Result, SseErrorOf>; +pub const MAX_CAPACITY: usize = isize::MAX as usize; + /// Default initial capacity of the buffer of the [`SseStream`]. pub const DEFAULT_INITIAL_CAPACITY: usize = 512; @@ -154,7 +156,7 @@ impl EventParser { initial_capacity: usize, max_capacity: usize, ) -> Self { - let max_capacity = max_capacity.min(isize::MAX as usize); + let max_capacity = max_capacity.min(MAX_CAPACITY); let initial_capacity = initial_capacity.min(max_capacity); Self { buf: BytesMut::with_capacity(initial_capacity), @@ -475,6 +477,8 @@ pub enum SseError { /// Failed to parse [`Event`]. #[error(transparent)] Parser(EventParseError), + #[error("too many header")] + TooManyHeaders(#[from] reqwest::header::MaxSizeReached), } // SSE STATE /////////////////////////////////////////////////////////////////// @@ -785,10 +789,8 @@ impl SseStreamBuilder { let mut request = reqwest::Request::new(Method::GET, url); let headers = request.headers_mut(); - - let _ = headers.insert(header::ACCEPT, TEXT_EVENT_STREAM); - - let _ = headers.insert(header::CACHE_CONTROL, NO_STORE); + let _ = headers.try_insert(header::ACCEPT, TEXT_EVENT_STREAM)?; + let _ = headers.try_insert(header::CACHE_CONTROL, NO_STORE)?; if let Some(last_event_id) = last_event_id.header_value() { let _ = headers.insert(LAST_EVENT_ID, last_event_id); @@ -796,10 +798,9 @@ impl SseStreamBuilder { // TODO: request's "initiator" type should be set to "other" - let first_request = match request.try_clone() { - None => return Err(SseError::RequestBodyNotCloneable), - Some(v) => v, - }; + let first_request = request + .try_clone() + .ok_or(SseError::RequestBodyNotCloneable)?; Ok(SseStream { state: SseState::Connecting(client.execute(first_request).boxed()), @@ -807,7 +808,7 @@ impl SseStreamBuilder { request.url().clone(), last_event_id, initial_capacity, - max_capacity.unwrap_or(isize::MAX as usize), + max_capacity.unwrap_or(MAX_CAPACITY), ), max_retry: max_retry.map(|n| (n, n)), max_loop: max_loop.map(|n| (n, n)), @@ -819,7 +820,7 @@ impl SseStreamBuilder { // SSE CLIENT ////////////////////////////////////////////////////////////////// -/// Extension trait for [`Request`]s clients that support subscribing to Server-Sent Events (SSE). +/// Extension trait for HTTP clients that support subscribing to Server-Sent Events (SSE). pub trait SseClient { /// Sends a GET HTTP request to the provided `endpoint`, /// which is expected to serve a stream of Server-Sent Events (SSE).