@@ -12,7 +12,7 @@ use std::ops::Deref;
1212use std:: sync:: atomic:: { self , AtomicUsize } ;
1313use std:: sync:: Arc ;
1414
15- use self :: hyper:: header:: HeaderValue ;
15+ use self :: hyper:: header:: { HeaderMap , HeaderValue } ;
1616use self :: url:: Url ;
1717use crate :: helpers;
1818use crate :: rpc;
@@ -69,17 +69,50 @@ pub type FetchTask<F> = Response<F, hyper::Chunk>;
6969pub struct Http {
7070 id : Arc < AtomicUsize > ,
7171 url : hyper:: Uri ,
72- basic_auth : Option < HeaderValue > ,
72+ headers : Option < HeaderMap > ,
7373 write_sender : mpsc:: UnboundedSender < ( hyper:: Request < hyper:: Body > , Pending ) > ,
7474}
7575
76+ struct EventLoopParams < ' a , ' b > {
77+ url : & ' a str ,
78+ max_parallel : usize ,
79+ handle : & ' b reactor:: Handle ,
80+ headers : Option < HeaderMap > ,
81+ }
82+
7683impl Http {
7784 /// Create new HTTP transport with given URL and spawn an event loop in a separate thread.
7885 /// NOTE: Dropping event loop handle will stop the transport layer!
7986 pub fn new ( url : & str ) -> Result < ( EventLoopHandle , Self ) > {
8087 Self :: with_max_parallel ( url, DEFAULT_MAX_PARALLEL )
8188 }
8289
90+ /// Create a HTTP transport with the given URL and spawn an event loop in a separate thread.
91+ /// You can provide custom headers to be passed to the HTTP requests.
92+ /// NOTE: Dropping event loop handle will stop the transport layer!
93+ pub fn with_headers ( url : & str , headers : HeaderMap ) -> Result < ( EventLoopHandle , Self ) > {
94+ Self :: with_max_parallel_and_headers ( url, DEFAULT_MAX_PARALLEL , headers)
95+ }
96+ /// Create a HTTP transport with the given URL and spawn an event loop in a separate thread.
97+ /// You can set a maximal number of parallel requests.
98+ /// You can provide custom headers to be passed to the HTTP requests.
99+ /// NOTE: Dropping event loop handle will stop the transport layer!
100+ pub fn with_max_parallel_and_headers (
101+ url : & str ,
102+ max_parallel : usize ,
103+ headers : HeaderMap ,
104+ ) -> Result < ( EventLoopHandle , Self ) > {
105+ let url = url. to_owned ( ) ;
106+ EventLoopHandle :: spawn ( move |handle| {
107+ Self :: with_event_loop_internal ( EventLoopParams {
108+ url : & url,
109+ handle,
110+ max_parallel,
111+ headers : Some ( headers) ,
112+ } )
113+ } )
114+ }
115+
83116 /// Create new HTTP transport with given URL and spawn an event loop in a separate thread.
84117 /// You can set a maximal number of parallel requests using the second parameter.
85118 /// NOTE: Dropping event loop handle will stop the transport layer!
@@ -90,6 +123,22 @@ impl Http {
90123
91124 /// Create new HTTP transport with given URL and existing event loop handle.
92125 pub fn with_event_loop ( url : & str , handle : & reactor:: Handle , max_parallel : usize ) -> Result < Self > {
126+ Self :: with_event_loop_internal ( EventLoopParams {
127+ url,
128+ handle,
129+ max_parallel,
130+ headers : None ,
131+ } )
132+ }
133+
134+ fn with_event_loop_internal ( params : EventLoopParams ) -> Result < Self > {
135+ let EventLoopParams {
136+ url,
137+ handle,
138+ max_parallel,
139+ mut headers,
140+ } = params;
141+
93142 let ( write_sender, write_receiver) = mpsc:: unbounded ( ) ;
94143
95144 #[ cfg( feature = "tls" ) ]
@@ -123,21 +172,30 @@ impl Http {
123172 } ) ,
124173 ) ;
125174
126- let basic_auth = {
127- let url = Url :: parse ( url) ?;
128- let user = url. username ( ) ;
129- let auth = format ! ( "{}:{}" , user, url. password( ) . unwrap_or_default( ) ) ;
130- if & auth == ":" {
131- None
132- } else {
133- Some ( HeaderValue :: from_str ( & format ! ( "Basic {}" , base64:: encode( & auth) ) ) ?)
134- }
135- } ;
175+ // Check if there is basic auth information in the URL
176+ let parsed_url = Url :: parse ( url) ?;
177+ let basic_auth = format ! (
178+ "{}:{}" ,
179+ parsed_url. username( ) ,
180+ parsed_url. password( ) . unwrap_or_default( )
181+ ) ;
182+ if basic_auth != ":" {
183+ // Add Authorization header for basic auth but ONLY if the
184+ // header isn't already present in the provided headers
185+ let basic_auth_header = HeaderValue :: from_str ( & format ! ( "Basic {}" , base64:: encode( & basic_auth) ) ) ?;
186+
187+ headers = Some ( headers. unwrap_or_default ( ) ) . map ( |mut h| {
188+ h. entry ( hyper:: header:: AUTHORIZATION )
189+ . unwrap ( )
190+ . or_insert ( basic_auth_header) ;
191+ h
192+ } ) ;
193+ }
136194
137195 Ok ( Http {
138196 id : Default :: default ( ) ,
139197 url : url. parse ( ) ?,
140- basic_auth ,
198+ headers ,
141199 write_sender,
142200 } )
143201 }
@@ -163,10 +221,9 @@ impl Http {
163221 if len < MAX_SINGLE_CHUNK {
164222 req. headers_mut ( ) . insert ( hyper:: header:: CONTENT_LENGTH , len. into ( ) ) ;
165223 }
166- // Send basic auth header
167- if let Some ( ref basic_auth) = self . basic_auth {
168- req. headers_mut ( )
169- . insert ( hyper:: header:: AUTHORIZATION , basic_auth. clone ( ) ) ;
224+ // Add headers
225+ if let Some ( ref headers) = self . headers {
226+ req. headers_mut ( ) . extend ( headers. clone ( ) )
170227 }
171228 let ( tx, rx) = futures:: oneshot ( ) ;
172229 let result = self
@@ -229,15 +286,15 @@ fn single_response<T: Deref<Target = [u8]>>(response: T) -> Result<rpc::Value> {
229286/// Parse bytes RPC batch response into `Result`.
230287fn batch_response < T : Deref < Target = [ u8 ] > > ( response : T ) -> Result < Vec < Result < rpc:: Value > > > {
231288 // See comment in `single_response`.
232- let mut json: Vec < serde_json:: Value > =
289+ let mut json: Vec < serde_json:: Value > =
233290 serde_json:: from_slice ( & * response) . map_err ( |e| Error :: InvalidResponse ( format ! ( "{:?}" , e) ) ) ?;
234291 for value in & mut json {
235292 if let Some ( id) = value. get_mut ( "id" ) {
236293 id. take ( ) ;
237294 }
238295 }
239296 let response = serde_json:: from_value ( serde_json:: Value :: Array ( json) )
240- . map_err ( |e| Error :: InvalidResponse ( format ! ( "{:?}" , e) ) ) ?;
297+ . map_err ( |e| Error :: InvalidResponse ( format ! ( "{:?}" , e) ) ) ?;
241298 match response {
242299 rpc:: Response :: Batch ( outputs) => Ok ( outputs. into_iter ( ) . map ( helpers:: to_result_from_output) . collect ( ) ) ,
243300 _ => Err ( Error :: InvalidResponse ( "Expected batch, got single." . into ( ) ) ) ,
@@ -251,47 +308,87 @@ mod tests {
251308 #[ test]
252309 fn http_supports_basic_auth_with_user_and_password ( ) {
253310 let http =
Http :: new ( "https://user:[email protected] :8545" ) ; 311+
312+ let mut expected_headers = HeaderMap :: new ( ) ;
313+ expected_headers. insert (
314+ hyper:: header:: AUTHORIZATION ,
315+ HeaderValue :: from_static ( "Basic dXNlcjpwYXNzd29yZA==" ) ,
316+ ) ;
317+
254318 assert ! ( http. is_ok( ) ) ;
319+
255320 match http {
256- Ok ( ( _, transport) ) => {
257- assert ! ( transport. basic_auth. is_some( ) ) ;
258- assert_eq ! (
259- transport. basic_auth,
260- Some ( HeaderValue :: from_static( "Basic dXNlcjpwYXNzd29yZA==" ) )
261- )
262- }
321+ Ok ( ( _, transport) ) => assert_eq ! ( transport. headers, Some ( expected_headers) ) ,
263322 Err ( _) => assert ! ( false , "" ) ,
264323 }
265324 }
266325
267326 #[ test]
268327 fn http_supports_basic_auth_with_user_no_password ( ) {
269328 let http = Http :: new ( "https://username:@127.0.0.1:8545" ) ;
329+
330+ let mut expected_headers = HeaderMap :: new ( ) ;
331+ expected_headers. insert (
332+ hyper:: header:: AUTHORIZATION ,
333+ HeaderValue :: from_static ( "Basic dXNlcm5hbWU6" ) ,
334+ ) ;
335+
270336 assert ! ( http. is_ok( ) ) ;
337+
271338 match http {
272- Ok ( ( _, transport) ) => {
273- assert ! ( transport. basic_auth. is_some( ) ) ;
274- assert_eq ! (
275- transport. basic_auth,
276- Some ( HeaderValue :: from_static( "Basic dXNlcm5hbWU6" ) )
277- )
278- }
339+ Ok ( ( _, transport) ) => assert_eq ! ( transport. headers, Some ( expected_headers) ) ,
279340 Err ( _) => assert ! ( false , "" ) ,
280341 }
281342 }
282343
283344 #[ test]
284345 fn http_supports_basic_auth_with_only_password ( ) {
285346 let http =
Http :: new ( "https://:[email protected] :8545" ) ; 347+
348+ let mut expected_headers = HeaderMap :: new ( ) ;
349+ expected_headers. insert (
350+ hyper:: header:: AUTHORIZATION ,
351+ HeaderValue :: from_static ( "Basic OnBhc3N3b3Jk" ) ,
352+ ) ;
353+
354+ assert ! ( http. is_ok( ) ) ;
355+ match http {
356+ Ok ( ( _, transport) ) => assert_eq ! ( transport. headers, Some ( expected_headers) ) ,
357+ Err ( _) => assert ! ( false , "" ) ,
358+ }
359+ }
360+
361+ #[ test]
362+ fn http_supports_custom_headers ( ) {
363+ let mut expected_headers = HeaderMap :: new ( ) ;
364+ expected_headers. insert (
365+ hyper:: header:: CONTENT_TYPE ,
366+ HeaderValue :: from_static ( "application/json" ) ,
367+ ) ;
368+
369+ let http = Http :: with_headers ( "https://127.0.0.1:8545" , expected_headers. clone ( ) ) ;
370+
371+ assert ! ( http. is_ok( ) ) ;
372+ match http {
373+ Ok ( ( _, transport) ) => assert_eq ! ( transport. headers, Some ( expected_headers) ) ,
374+ Err ( _) => assert ! ( false , "" ) ,
375+ }
376+ }
377+
378+ #[ test]
379+ fn http_basic_auth_does_not_override_authorization_header ( ) {
380+ let mut expected_headers = HeaderMap :: new ( ) ;
381+ expected_headers. insert ( hyper:: header:: AUTHORIZATION , HeaderValue :: from_static ( "Bearer foo" ) ) ;
382+ expected_headers. insert (
383+ hyper:: header:: CONTENT_TYPE ,
384+ HeaderValue :: from_static ( "application/json" ) ,
385+ ) ;
386+
387+ let http =
Http :: with_headers ( "https://username:[email protected] :8545" , expected_headers
. clone ( ) ) ; 388+
286389 assert ! ( http. is_ok( ) ) ;
287390 match http {
288- Ok ( ( _, transport) ) => {
289- assert ! ( transport. basic_auth. is_some( ) ) ;
290- assert_eq ! (
291- transport. basic_auth,
292- Some ( HeaderValue :: from_static( "Basic OnBhc3N3b3Jk" ) )
293- )
294- }
391+ Ok ( ( _, transport) ) => assert_eq ! ( transport. headers, Some ( expected_headers) ) ,
295392 Err ( _) => assert ! ( false , "" ) ,
296393 }
297394 }
0 commit comments