@@ -19,11 +19,11 @@ use bytes::Buf;
19
19
use crypto_common:: InvalidLength ;
20
20
use hmac:: { Hmac , Mac } ;
21
21
#[ cfg( feature = "logging" ) ]
22
- use log:: { Level , error, log_enabled , trace} ;
22
+ use log:: { error, trace} ;
23
23
#[ cfg( feature = "metrics" ) ]
24
24
use prometheus:: { CounterVec , opts, register_counter_vec} ;
25
25
use reqwest:: {
26
- Method , StatusCode ,
26
+ IntoUrl , Method , StatusCode ,
27
27
header:: { self , HeaderValue } ,
28
28
} ;
29
29
use serde:: { Deserialize , Serialize , de:: DeserializeOwned } ;
@@ -71,40 +71,26 @@ static CLIENT_REQUEST_DURATION: LazyLock<CounterVec> = LazyLock::new(|| {
71
71
type HmacSha512 = Hmac < Sha512 > ;
72
72
73
73
// -----------------------------------------------------------------------------
74
- // Request trait
75
-
76
- pub trait Request {
77
- type Error ;
74
+ // RestClient trait
78
75
76
+ pub trait RestClient < X > : Execute {
79
77
fn request < T , U > (
80
78
& self ,
81
79
method : & Method ,
82
- endpoint : & str ,
80
+ endpoint : X ,
83
81
payload : & T ,
84
- ) -> impl Future < Output = Result < U , Self :: Error > > + Send
82
+ ) -> impl Future < Output = Result < U , < Self as Execute > :: Error > > + Send
85
83
where
86
84
T : ?Sized + Serialize + Debug + Send + Sync ,
87
85
U : DeserializeOwned + Debug + Send + Sync ;
88
86
89
- fn execute (
90
- & self ,
91
- request : reqwest:: Request ,
92
- ) -> impl Future < Output = Result < reqwest:: Response , Self :: Error > > + Send + ' static ;
93
- }
94
-
95
- // -----------------------------------------------------------------------------
96
- // RestClient trait
97
-
98
- pub trait RestClient : Debug {
99
- type Error ;
100
-
101
- fn get < T > ( & self , endpoint : & str ) -> impl Future < Output = Result < T , Self :: Error > > + Send
87
+ fn get < T > ( & self , endpoint : X ) -> impl Future < Output = Result < T , Self :: Error > > + Send
102
88
where
103
89
T : DeserializeOwned + Debug + Send + Sync ;
104
90
105
91
fn post < T , U > (
106
92
& self ,
107
- endpoint : & str ,
93
+ endpoint : X ,
108
94
payload : & T ,
109
95
) -> impl Future < Output = Result < U , Self :: Error > > + Send
110
96
where
@@ -113,7 +99,7 @@ pub trait RestClient: Debug {
113
99
114
100
fn put < T , U > (
115
101
& self ,
116
- endpoint : & str ,
102
+ endpoint : X ,
117
103
payload : & T ,
118
104
) -> impl Future < Output = Result < U , Self :: Error > > + Send
119
105
where
@@ -122,20 +108,20 @@ pub trait RestClient: Debug {
122
108
123
109
fn patch < T , U > (
124
110
& self ,
125
- endpoint : & str ,
111
+ endpoint : X ,
126
112
payload : & T ,
127
113
) -> impl Future < Output = Result < U , Self :: Error > > + Send
128
114
where
129
115
T : ?Sized + Serialize + Debug + Send + Sync ,
130
116
U : DeserializeOwned + Debug + Send + Sync ;
131
117
132
- fn delete ( & self , endpoint : & str ) -> impl Future < Output = Result < ( ) , Self :: Error > > + Send ;
118
+ fn delete ( & self , endpoint : X ) -> impl Future < Output = Result < ( ) , Self :: Error > > + Send ;
133
119
}
134
120
135
121
// -----------------------------------------------------------------------------
136
122
// ClientCredentials structure
137
123
138
- #[ derive( Serialize , Deserialize , PartialEq , Eq , Clone , Debug ) ]
124
+ #[ derive( Serialize , Deserialize , PartialEq , Eq , Clone ) ]
139
125
#[ serde( untagged) ]
140
126
pub enum Credentials {
141
127
OAuth1 {
@@ -160,6 +146,17 @@ pub enum Credentials {
160
146
} ,
161
147
}
162
148
149
+ impl fmt:: Debug for Credentials {
150
+ fn fmt ( & self , f : & mut Formatter < ' _ > ) -> fmt:: Result {
151
+ // NOTE: ensure secrets are not leaked in logs
152
+ match self {
153
+ Self :: OAuth1 { .. } => f. write_str ( "OAuth1" ) ,
154
+ Self :: Basic { .. } => f. write_str ( "Basic" ) ,
155
+ Self :: Bearer { .. } => f. write_str ( "Bearer" ) ,
156
+ }
157
+ }
158
+ }
159
+
163
160
impl Default for Credentials {
164
161
fn default ( ) -> Self {
165
162
Self :: OAuth1 {
@@ -335,19 +332,16 @@ impl OAuth1 for Signer {
335
332
336
333
if !query. is_empty ( ) {
337
334
for qparam in query. split ( '&' ) {
338
- let ( k, v) = qparam. split_at ( qparam . find ( '=' ) . ok_or_else ( || {
335
+ let ( k, v) = qparam. split_once ( "=" ) . ok_or_else ( || {
339
336
SignerError :: Parse ( format ! ( "failed to parse query parameter, {qparam}" ) )
340
- } ) ?) ;
341
-
342
- if !params. contains_key ( k) {
343
- params. insert ( k. to_string ( ) , v. strip_prefix ( '=' ) . unwrap_or ( v) . to_owned ( ) ) ;
344
- }
337
+ } ) ?;
338
+ params. entry ( k. to_owned ( ) ) . or_insert ( v. to_owned ( ) ) ;
345
339
}
346
340
}
347
341
348
342
let mut params = params
349
343
. iter ( )
350
- . map ( |( k, v) | format ! ( "{}={}" , k , urlencoding:: encode( v) ) )
344
+ . map ( |( k, v) | format ! ( "{k }={}" , urlencoding:: encode( v) ) )
351
345
. collect :: < Vec < _ > > ( ) ;
352
346
353
347
params. sort ( ) ;
@@ -429,8 +423,6 @@ pub enum ClientError {
429
423
Digest ( SignerError ) ,
430
424
#[ error( "failed to serialize signature as header value, {0}" ) ]
431
425
SerializeHeaderValue ( header:: InvalidHeaderValue ) ,
432
- #[ error( "failed to parse url endpoint, {0}" ) ]
433
- ParseUrlEndpoint ( url:: ParseError ) ,
434
426
}
435
427
436
428
// -----------------------------------------------------------------------------
@@ -446,69 +438,24 @@ pub struct Client {
446
438
credentials : Option < Credentials > ,
447
439
}
448
440
449
- impl Request for Client {
450
- type Error = ClientError ;
441
+ pub trait Execute {
442
+ type Error ;
451
443
452
- #[ cfg_attr( feature = "tracing" , tracing:: instrument) ]
453
- async fn request < T , U > (
444
+ fn execute (
454
445
& self ,
455
- method : & Method ,
456
- endpoint : & str ,
457
- payload : & T ,
458
- ) -> Result < U , Self :: Error >
459
- where
460
- T : ?Sized + Serialize + Debug + Send + Sync ,
461
- U : DeserializeOwned + Debug + Send + Sync ,
462
- {
463
- let buf = serde_json:: to_vec ( payload) . map_err ( ClientError :: Serialize ) ?;
464
- let mut request = reqwest:: Request :: new (
465
- method. to_owned ( ) ,
466
- endpoint. parse ( ) . map_err ( ClientError :: ParseUrlEndpoint ) ?,
467
- ) ;
468
-
469
- request
470
- . headers_mut ( )
471
- . insert ( header:: CONTENT_TYPE , APPLICATION_JSON ) ;
472
-
473
- request
474
- . headers_mut ( )
475
- . insert ( header:: CONTENT_LENGTH , HeaderValue :: from ( buf. len ( ) ) ) ;
476
-
477
- request. headers_mut ( ) . insert ( header:: ACCEPT_CHARSET , UTF8 ) ;
478
-
479
- request
480
- . headers_mut ( )
481
- . insert ( header:: ACCEPT , APPLICATION_JSON ) ;
482
-
483
- * request. body_mut ( ) = Some ( buf. into ( ) ) ;
484
-
485
- let res = self . execute ( request) . await ?;
486
- let status = res. status ( ) ;
487
- let buf = res. bytes ( ) . await . map_err ( ClientError :: BodyAggregation ) ?;
488
-
489
- #[ cfg( feature = "logging" ) ]
490
- if log_enabled ! ( Level :: Trace ) {
491
- trace ! (
492
- "received response, endpoint: '{endpoint}', method: '{method}', status: '{}'" ,
493
- status. as_u16( )
494
- ) ;
495
- }
496
-
497
- if !status. is_success ( ) {
498
- return Err ( ClientError :: StatusCode (
499
- status,
500
- serde_json:: from_reader ( buf. reader ( ) ) . map_err ( ClientError :: Deserialize ) ?,
501
- ) ) ;
502
- }
446
+ request : reqwest:: Request ,
447
+ ) -> impl Future < Output = Result < reqwest:: Response , Self :: Error > > + Send + ' static ;
448
+ }
503
449
504
- serde_json :: from_reader ( buf . reader ( ) ) . map_err ( ClientError :: Deserialize )
505
- }
450
+ impl Execute for Client {
451
+ type Error = ClientError ;
506
452
453
+ /// Executes the given HTTP `request`.
507
454
#[ cfg_attr( feature = "tracing" , tracing:: instrument( skip( self ) ) ) ]
508
455
fn execute (
509
456
& self ,
510
457
mut request : reqwest:: Request ,
511
- ) -> impl Future < Output = Result < reqwest:: Response , Self :: Error > > + ' static {
458
+ ) -> impl Future < Output = Result < reqwest:: Response , Self :: Error > > + Send + ' static {
512
459
let client = self . clone ( ) ;
513
460
514
461
async move {
@@ -582,18 +529,63 @@ impl Request for Client {
582
529
}
583
530
}
584
531
585
- impl RestClient for Client {
586
- type Error = ClientError ;
532
+ impl < X : IntoUrl + fmt:: Debug + Send > RestClient < X > for Client {
533
+ #[ cfg_attr( feature = "tracing" , tracing:: instrument) ]
534
+ async fn request < T , U > (
535
+ & self ,
536
+ method : & Method ,
537
+ endpoint : X ,
538
+ payload : & T ,
539
+ ) -> Result < U , Self :: Error >
540
+ where
541
+ T : ?Sized + Serialize + Debug + Send + Sync ,
542
+ U : DeserializeOwned + Debug + Send + Sync ,
543
+ {
544
+ let buf = serde_json:: to_vec ( payload) . map_err ( ClientError :: Serialize ) ?;
545
+
546
+ let url = endpoint. into_url ( ) . map_err ( ClientError :: Request ) ?;
547
+
548
+ #[ cfg( feature = "logging" ) ]
549
+ let endpoint = url. as_str ( ) . to_owned ( ) ;
550
+
551
+ let mut request = reqwest:: Request :: new ( method. to_owned ( ) , url) ;
552
+
553
+ let headers = request. headers_mut ( ) ;
554
+ headers. insert ( header:: CONTENT_TYPE , APPLICATION_JSON ) ;
555
+ headers. insert ( header:: CONTENT_LENGTH , HeaderValue :: from ( buf. len ( ) ) ) ;
556
+ headers. insert ( header:: ACCEPT_CHARSET , UTF8 ) ;
557
+ headers. insert ( header:: ACCEPT , APPLICATION_JSON ) ;
558
+
559
+ * request. body_mut ( ) = Some ( buf. into ( ) ) ;
560
+
561
+ let res = self . execute ( request) . await ?;
562
+ let status = res. status ( ) ;
563
+ let buf = res. bytes ( ) . await . map_err ( ClientError :: BodyAggregation ) ?;
564
+
565
+ #[ cfg( feature = "logging" ) ]
566
+ trace ! (
567
+ "received response, endpoint: '{endpoint}', method: '{method}', status: '{}'" ,
568
+ status. as_u16( )
569
+ ) ;
570
+
571
+ if !status. is_success ( ) {
572
+ return Err ( ClientError :: StatusCode (
573
+ status,
574
+ serde_json:: from_reader ( buf. reader ( ) ) . map_err ( ClientError :: Deserialize ) ?,
575
+ ) ) ;
576
+ }
577
+
578
+ serde_json:: from_reader ( buf. reader ( ) ) . map_err ( ClientError :: Deserialize )
579
+ }
587
580
588
581
#[ cfg_attr( feature = "tracing" , tracing:: instrument) ]
589
- async fn get < T > ( & self , endpoint : & str ) -> Result < T , Self :: Error >
582
+ async fn get < T > ( & self , endpoint : X ) -> Result < T , Self :: Error >
590
583
where
591
584
T : DeserializeOwned + Debug + Send + Sync ,
592
585
{
593
- let mut req = reqwest:: Request :: new (
594
- Method :: GET ,
595
- endpoint. parse ( ) . map_err ( ClientError :: ParseUrlEndpoint ) ?,
596
- ) ;
586
+ let url = endpoint. into_url ( ) . map_err ( ClientError :: Request ) ?;
587
+
588
+ let mut req = reqwest:: Request :: new ( Method :: GET , url) ;
597
589
598
590
req. headers_mut ( ) . insert ( header:: ACCEPT_CHARSET , UTF8 ) ;
599
591
@@ -614,7 +606,7 @@ impl RestClient for Client {
614
606
}
615
607
616
608
#[ cfg_attr( feature = "tracing" , tracing:: instrument) ]
617
- async fn post < T , U > ( & self , endpoint : & str , payload : & T ) -> Result < U , Self :: Error >
609
+ async fn post < T , U > ( & self , endpoint : X , payload : & T ) -> Result < U , Self :: Error >
618
610
where
619
611
T : ?Sized + Serialize + Debug + Send + Sync ,
620
612
U : DeserializeOwned + Debug + Send + Sync ,
@@ -623,7 +615,7 @@ impl RestClient for Client {
623
615
}
624
616
625
617
#[ cfg_attr( feature = "tracing" , tracing:: instrument) ]
626
- async fn put < T , U > ( & self , endpoint : & str , payload : & T ) -> Result < U , Self :: Error >
618
+ async fn put < T , U > ( & self , endpoint : X , payload : & T ) -> Result < U , Self :: Error >
627
619
where
628
620
T : ?Sized + Serialize + Debug + Send + Sync ,
629
621
U : DeserializeOwned + Debug + Send + Sync ,
@@ -632,7 +624,7 @@ impl RestClient for Client {
632
624
}
633
625
634
626
#[ cfg_attr( feature = "tracing" , tracing:: instrument) ]
635
- async fn patch < T , U > ( & self , endpoint : & str , payload : & T ) -> Result < U , Self :: Error >
627
+ async fn patch < T , U > ( & self , endpoint : X , payload : & T ) -> Result < U , Self :: Error >
636
628
where
637
629
T : ?Sized + Serialize + Debug + Send + Sync ,
638
630
U : DeserializeOwned + Debug + Send + Sync ,
@@ -641,11 +633,9 @@ impl RestClient for Client {
641
633
}
642
634
643
635
#[ cfg_attr( feature = "tracing" , tracing:: instrument) ]
644
- async fn delete ( & self , endpoint : & str ) -> Result < ( ) , Self :: Error > {
645
- let req = reqwest:: Request :: new (
646
- Method :: DELETE ,
647
- endpoint. parse ( ) . map_err ( ClientError :: ParseUrlEndpoint ) ?,
648
- ) ;
636
+ async fn delete ( & self , endpoint : X ) -> Result < ( ) , Self :: Error > {
637
+ let url = endpoint. into_url ( ) . map_err ( ClientError :: Request ) ?;
638
+ let req = reqwest:: Request :: new ( Method :: DELETE , url) ;
649
639
650
640
let res = self . execute ( req) . await ?;
651
641
let status = res. status ( ) ;
0 commit comments