Skip to content

Commit a063b17

Browse files
committed
* feat(sse): add SseErrorOf type alias
* feat: let endpoint argument be any type that impl reqwest::IntoUrl moved request to RestClient trait for IMHO better separation of concerns moved execute to Execute trait for better type inference * refactor: simplify qparams parsing * security: ensure credentials are not leaked to logs
1 parent 69e2f8d commit a063b17

File tree

2 files changed

+110
-120
lines changed

2 files changed

+110
-120
lines changed

src/client/mod.rs

Lines changed: 95 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@ use bytes::Buf;
1919
use crypto_common::InvalidLength;
2020
use hmac::{Hmac, Mac};
2121
#[cfg(feature = "logging")]
22-
use log::{Level, error, log_enabled, trace};
22+
use log::{error, trace};
2323
#[cfg(feature = "metrics")]
2424
use prometheus::{CounterVec, opts, register_counter_vec};
2525
use reqwest::{
26-
Method, StatusCode,
26+
IntoUrl, Method, StatusCode,
2727
header::{self, HeaderValue},
2828
};
2929
use serde::{Deserialize, Serialize, de::DeserializeOwned};
@@ -71,40 +71,26 @@ static CLIENT_REQUEST_DURATION: LazyLock<CounterVec> = LazyLock::new(|| {
7171
type HmacSha512 = Hmac<Sha512>;
7272

7373
// -----------------------------------------------------------------------------
74-
// Request trait
75-
76-
pub trait Request {
77-
type Error;
74+
// RestClient trait
7875

76+
pub trait RestClient<X>: Execute {
7977
fn request<T, U>(
8078
&self,
8179
method: &Method,
82-
endpoint: &str,
80+
endpoint: X,
8381
payload: &T,
84-
) -> impl Future<Output = Result<U, Self::Error>> + Send
82+
) -> impl Future<Output = Result<U, <Self as Execute>::Error>> + Send
8583
where
8684
T: ?Sized + Serialize + Debug + Send + Sync,
8785
U: DeserializeOwned + Debug + Send + Sync;
8886

89-
fn execute(
90-
&self,
91-
request: reqwest::Request,
92-
) -> impl Future<Output = Result<reqwest::Response, Self::Error>> + Send + 'static;
93-
}
94-
95-
// -----------------------------------------------------------------------------
96-
// RestClient trait
97-
98-
pub trait RestClient: Debug {
99-
type Error;
100-
101-
fn get<T>(&self, endpoint: &str) -> impl Future<Output = Result<T, Self::Error>> + Send
87+
fn get<T>(&self, endpoint: X) -> impl Future<Output = Result<T, Self::Error>> + Send
10288
where
10389
T: DeserializeOwned + Debug + Send + Sync;
10490

10591
fn post<T, U>(
10692
&self,
107-
endpoint: &str,
93+
endpoint: X,
10894
payload: &T,
10995
) -> impl Future<Output = Result<U, Self::Error>> + Send
11096
where
@@ -113,7 +99,7 @@ pub trait RestClient: Debug {
11399

114100
fn put<T, U>(
115101
&self,
116-
endpoint: &str,
102+
endpoint: X,
117103
payload: &T,
118104
) -> impl Future<Output = Result<U, Self::Error>> + Send
119105
where
@@ -122,20 +108,20 @@ pub trait RestClient: Debug {
122108

123109
fn patch<T, U>(
124110
&self,
125-
endpoint: &str,
111+
endpoint: X,
126112
payload: &T,
127113
) -> impl Future<Output = Result<U, Self::Error>> + Send
128114
where
129115
T: ?Sized + Serialize + Debug + Send + Sync,
130116
U: DeserializeOwned + Debug + Send + Sync;
131117

132-
fn delete(&self, endpoint: &str) -> impl Future<Output = Result<(), Self::Error>> + Send;
118+
fn delete(&self, endpoint: X) -> impl Future<Output = Result<(), Self::Error>> + Send;
133119
}
134120

135121
// -----------------------------------------------------------------------------
136122
// ClientCredentials structure
137123

138-
#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug)]
124+
#[derive(Serialize, Deserialize, PartialEq, Eq, Clone)]
139125
#[serde(untagged)]
140126
pub enum Credentials {
141127
OAuth1 {
@@ -160,6 +146,17 @@ pub enum Credentials {
160146
},
161147
}
162148

149+
impl fmt::Debug for Credentials {
150+
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
151+
// NOTE: ensure secrets are not leaked in logs
152+
match self {
153+
Self::OAuth1 { .. } => f.write_str("OAuth1"),
154+
Self::Basic { .. } => f.write_str("Basic"),
155+
Self::Bearer { .. } => f.write_str("Bearer"),
156+
}
157+
}
158+
}
159+
163160
impl Default for Credentials {
164161
fn default() -> Self {
165162
Self::OAuth1 {
@@ -335,19 +332,16 @@ impl OAuth1 for Signer {
335332

336333
if !query.is_empty() {
337334
for qparam in query.split('&') {
338-
let (k, v) = qparam.split_at(qparam.find('=').ok_or_else(|| {
335+
let (k, v) = qparam.split_once("=").ok_or_else(|| {
339336
SignerError::Parse(format!("failed to parse query parameter, {qparam}"))
340-
})?);
341-
342-
if !params.contains_key(k) {
343-
params.insert(k.to_string(), v.strip_prefix('=').unwrap_or(v).to_owned());
344-
}
337+
})?;
338+
params.entry(k.to_owned()).or_insert(v.to_owned());
345339
}
346340
}
347341

348342
let mut params = params
349343
.iter()
350-
.map(|(k, v)| format!("{}={}", k, urlencoding::encode(v)))
344+
.map(|(k, v)| format!("{k}={}", urlencoding::encode(v)))
351345
.collect::<Vec<_>>();
352346

353347
params.sort();
@@ -429,8 +423,6 @@ pub enum ClientError {
429423
Digest(SignerError),
430424
#[error("failed to serialize signature as header value, {0}")]
431425
SerializeHeaderValue(header::InvalidHeaderValue),
432-
#[error("failed to parse url endpoint, {0}")]
433-
ParseUrlEndpoint(url::ParseError),
434426
}
435427

436428
// -----------------------------------------------------------------------------
@@ -446,69 +438,24 @@ pub struct Client {
446438
credentials: Option<Credentials>,
447439
}
448440

449-
impl Request for Client {
450-
type Error = ClientError;
441+
pub trait Execute {
442+
type Error;
451443

452-
#[cfg_attr(feature = "tracing", tracing::instrument)]
453-
async fn request<T, U>(
444+
fn execute(
454445
&self,
455-
method: &Method,
456-
endpoint: &str,
457-
payload: &T,
458-
) -> Result<U, Self::Error>
459-
where
460-
T: ?Sized + Serialize + Debug + Send + Sync,
461-
U: DeserializeOwned + Debug + Send + Sync,
462-
{
463-
let buf = serde_json::to_vec(payload).map_err(ClientError::Serialize)?;
464-
let mut request = reqwest::Request::new(
465-
method.to_owned(),
466-
endpoint.parse().map_err(ClientError::ParseUrlEndpoint)?,
467-
);
468-
469-
request
470-
.headers_mut()
471-
.insert(header::CONTENT_TYPE, APPLICATION_JSON);
472-
473-
request
474-
.headers_mut()
475-
.insert(header::CONTENT_LENGTH, HeaderValue::from(buf.len()));
476-
477-
request.headers_mut().insert(header::ACCEPT_CHARSET, UTF8);
478-
479-
request
480-
.headers_mut()
481-
.insert(header::ACCEPT, APPLICATION_JSON);
482-
483-
*request.body_mut() = Some(buf.into());
484-
485-
let res = self.execute(request).await?;
486-
let status = res.status();
487-
let buf = res.bytes().await.map_err(ClientError::BodyAggregation)?;
488-
489-
#[cfg(feature = "logging")]
490-
if log_enabled!(Level::Trace) {
491-
trace!(
492-
"received response, endpoint: '{endpoint}', method: '{method}', status: '{}'",
493-
status.as_u16()
494-
);
495-
}
496-
497-
if !status.is_success() {
498-
return Err(ClientError::StatusCode(
499-
status,
500-
serde_json::from_reader(buf.reader()).map_err(ClientError::Deserialize)?,
501-
));
502-
}
446+
request: reqwest::Request,
447+
) -> impl Future<Output = Result<reqwest::Response, Self::Error>> + Send + 'static;
448+
}
503449

504-
serde_json::from_reader(buf.reader()).map_err(ClientError::Deserialize)
505-
}
450+
impl Execute for Client {
451+
type Error = ClientError;
506452

453+
/// Executes the given HTTP `request`.
507454
#[cfg_attr(feature = "tracing", tracing::instrument(skip(self)))]
508455
fn execute(
509456
&self,
510457
mut request: reqwest::Request,
511-
) -> impl Future<Output = Result<reqwest::Response, Self::Error>> + 'static {
458+
) -> impl Future<Output = Result<reqwest::Response, Self::Error>> + Send + 'static {
512459
let client = self.clone();
513460

514461
async move {
@@ -582,18 +529,63 @@ impl Request for Client {
582529
}
583530
}
584531

585-
impl RestClient for Client {
586-
type Error = ClientError;
532+
impl<X: IntoUrl + fmt::Debug + Send> RestClient<X> for Client {
533+
#[cfg_attr(feature = "tracing", tracing::instrument)]
534+
async fn request<T, U>(
535+
&self,
536+
method: &Method,
537+
endpoint: X,
538+
payload: &T,
539+
) -> Result<U, Self::Error>
540+
where
541+
T: ?Sized + Serialize + Debug + Send + Sync,
542+
U: DeserializeOwned + Debug + Send + Sync,
543+
{
544+
let buf = serde_json::to_vec(payload).map_err(ClientError::Serialize)?;
545+
546+
let url = endpoint.into_url().map_err(ClientError::Request)?;
547+
548+
#[cfg(feature = "logging")]
549+
let endpoint = url.as_str().to_owned();
550+
551+
let mut request = reqwest::Request::new(method.to_owned(), url);
552+
553+
let headers = request.headers_mut();
554+
headers.insert(header::CONTENT_TYPE, APPLICATION_JSON);
555+
headers.insert(header::CONTENT_LENGTH, HeaderValue::from(buf.len()));
556+
headers.insert(header::ACCEPT_CHARSET, UTF8);
557+
headers.insert(header::ACCEPT, APPLICATION_JSON);
558+
559+
*request.body_mut() = Some(buf.into());
560+
561+
let res = self.execute(request).await?;
562+
let status = res.status();
563+
let buf = res.bytes().await.map_err(ClientError::BodyAggregation)?;
564+
565+
#[cfg(feature = "logging")]
566+
trace!(
567+
"received response, endpoint: '{endpoint}', method: '{method}', status: '{}'",
568+
status.as_u16()
569+
);
570+
571+
if !status.is_success() {
572+
return Err(ClientError::StatusCode(
573+
status,
574+
serde_json::from_reader(buf.reader()).map_err(ClientError::Deserialize)?,
575+
));
576+
}
577+
578+
serde_json::from_reader(buf.reader()).map_err(ClientError::Deserialize)
579+
}
587580

588581
#[cfg_attr(feature = "tracing", tracing::instrument)]
589-
async fn get<T>(&self, endpoint: &str) -> Result<T, Self::Error>
582+
async fn get<T>(&self, endpoint: X) -> Result<T, Self::Error>
590583
where
591584
T: DeserializeOwned + Debug + Send + Sync,
592585
{
593-
let mut req = reqwest::Request::new(
594-
Method::GET,
595-
endpoint.parse().map_err(ClientError::ParseUrlEndpoint)?,
596-
);
586+
let url = endpoint.into_url().map_err(ClientError::Request)?;
587+
588+
let mut req = reqwest::Request::new(Method::GET, url);
597589

598590
req.headers_mut().insert(header::ACCEPT_CHARSET, UTF8);
599591

@@ -614,7 +606,7 @@ impl RestClient for Client {
614606
}
615607

616608
#[cfg_attr(feature = "tracing", tracing::instrument)]
617-
async fn post<T, U>(&self, endpoint: &str, payload: &T) -> Result<U, Self::Error>
609+
async fn post<T, U>(&self, endpoint: X, payload: &T) -> Result<U, Self::Error>
618610
where
619611
T: ?Sized + Serialize + Debug + Send + Sync,
620612
U: DeserializeOwned + Debug + Send + Sync,
@@ -623,7 +615,7 @@ impl RestClient for Client {
623615
}
624616

625617
#[cfg_attr(feature = "tracing", tracing::instrument)]
626-
async fn put<T, U>(&self, endpoint: &str, payload: &T) -> Result<U, Self::Error>
618+
async fn put<T, U>(&self, endpoint: X, payload: &T) -> Result<U, Self::Error>
627619
where
628620
T: ?Sized + Serialize + Debug + Send + Sync,
629621
U: DeserializeOwned + Debug + Send + Sync,
@@ -632,7 +624,7 @@ impl RestClient for Client {
632624
}
633625

634626
#[cfg_attr(feature = "tracing", tracing::instrument)]
635-
async fn patch<T, U>(&self, endpoint: &str, payload: &T) -> Result<U, Self::Error>
627+
async fn patch<T, U>(&self, endpoint: X, payload: &T) -> Result<U, Self::Error>
636628
where
637629
T: ?Sized + Serialize + Debug + Send + Sync,
638630
U: DeserializeOwned + Debug + Send + Sync,
@@ -641,11 +633,9 @@ impl RestClient for Client {
641633
}
642634

643635
#[cfg_attr(feature = "tracing", tracing::instrument)]
644-
async fn delete(&self, endpoint: &str) -> Result<(), Self::Error> {
645-
let req = reqwest::Request::new(
646-
Method::DELETE,
647-
endpoint.parse().map_err(ClientError::ParseUrlEndpoint)?,
648-
);
636+
async fn delete(&self, endpoint: X) -> Result<(), Self::Error> {
637+
let url = endpoint.into_url().map_err(ClientError::Request)?;
638+
let req = reqwest::Request::new(Method::DELETE, url);
649639

650640
let res = self.execute(req).await?;
651641
let status = res.status();

0 commit comments

Comments
 (0)