Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 88bab87

Browse files
committed
Add custom headers support to HTTP transport
1 parent 40f9cb5 commit 88bab87

File tree

1 file changed

+137
-40
lines changed

1 file changed

+137
-40
lines changed

src/transports/http.rs

Lines changed: 137 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use std::ops::Deref;
1212
use std::sync::atomic::{self, AtomicUsize};
1313
use std::sync::Arc;
1414

15-
use self::hyper::header::HeaderValue;
15+
use self::hyper::header::{HeaderMap, HeaderValue};
1616
use self::url::Url;
1717
use crate::helpers;
1818
use crate::rpc;
@@ -69,17 +69,50 @@ pub type FetchTask<F> = Response<F, hyper::Chunk>;
6969
pub struct Http {
7070
id: Arc<AtomicUsize>,
7171
url: hyper::Uri,
72-
basic_auth: Option<HeaderValue>,
72+
headers: Option<HeaderMap>,
7373
write_sender: mpsc::UnboundedSender<(hyper::Request<hyper::Body>, Pending)>,
7474
}
7575

76+
struct EventLoopParams<'a, 'b> {
77+
url: &'a str,
78+
max_parallel: usize,
79+
handle: &'b reactor::Handle,
80+
headers: Option<HeaderMap>,
81+
}
82+
7683
impl Http {
7784
/// Create new HTTP transport with given URL and spawn an event loop in a separate thread.
7885
/// NOTE: Dropping event loop handle will stop the transport layer!
7986
pub fn new(url: &str) -> Result<(EventLoopHandle, Self)> {
8087
Self::with_max_parallel(url, DEFAULT_MAX_PARALLEL)
8188
}
8289

90+
/// Create a HTTP transport with the given URL and spawn an event loop in a separate thread.
91+
/// You can provide custom headers to be passed to the HTTP requests.
92+
/// NOTE: Dropping event loop handle will stop the transport layer!
93+
pub fn with_headers(url: &str, headers: HeaderMap) -> Result<(EventLoopHandle, Self)> {
94+
Self::with_max_parallel_and_headers(url, DEFAULT_MAX_PARALLEL, headers)
95+
}
96+
/// Create a HTTP transport with the given URL and spawn an event loop in a separate thread.
97+
/// You can set a maximal number of parallel requests.
98+
/// You can provide custom headers to be passed to the HTTP requests.
99+
/// NOTE: Dropping event loop handle will stop the transport layer!
100+
pub fn with_max_parallel_and_headers(
101+
url: &str,
102+
max_parallel: usize,
103+
headers: HeaderMap,
104+
) -> Result<(EventLoopHandle, Self)> {
105+
let url = url.to_owned();
106+
EventLoopHandle::spawn(move |handle| {
107+
Self::with_event_loop_internal(EventLoopParams {
108+
url: &url,
109+
handle,
110+
max_parallel,
111+
headers: Some(headers),
112+
})
113+
})
114+
}
115+
83116
/// Create new HTTP transport with given URL and spawn an event loop in a separate thread.
84117
/// You can set a maximal number of parallel requests using the second parameter.
85118
/// NOTE: Dropping event loop handle will stop the transport layer!
@@ -90,6 +123,22 @@ impl Http {
90123

91124
/// Create new HTTP transport with given URL and existing event loop handle.
92125
pub fn with_event_loop(url: &str, handle: &reactor::Handle, max_parallel: usize) -> Result<Self> {
126+
Self::with_event_loop_internal(EventLoopParams {
127+
url,
128+
handle,
129+
max_parallel,
130+
headers: None,
131+
})
132+
}
133+
134+
fn with_event_loop_internal(params: EventLoopParams) -> Result<Self> {
135+
let EventLoopParams {
136+
url,
137+
handle,
138+
max_parallel,
139+
mut headers,
140+
} = params;
141+
93142
let (write_sender, write_receiver) = mpsc::unbounded();
94143

95144
#[cfg(feature = "tls")]
@@ -123,21 +172,30 @@ impl Http {
123172
}),
124173
);
125174

126-
let basic_auth = {
127-
let url = Url::parse(url)?;
128-
let user = url.username();
129-
let auth = format!("{}:{}", user, url.password().unwrap_or_default());
130-
if &auth == ":" {
131-
None
132-
} else {
133-
Some(HeaderValue::from_str(&format!("Basic {}", base64::encode(&auth)))?)
134-
}
135-
};
175+
// Check if there is basic auth information in the URL
176+
let parsed_url = Url::parse(url)?;
177+
let basic_auth = format!(
178+
"{}:{}",
179+
parsed_url.username(),
180+
parsed_url.password().unwrap_or_default()
181+
);
182+
if basic_auth != ":" {
183+
// Add Authorization header for basic auth but ONLY if the
184+
// header isn't already present in the provided headers
185+
let basic_auth_header = HeaderValue::from_str(&format!("Basic {}", base64::encode(&basic_auth)))?;
186+
187+
headers = Some(headers.unwrap_or_default()).map(|mut h| {
188+
h.entry(hyper::header::AUTHORIZATION)
189+
.unwrap()
190+
.or_insert(basic_auth_header);
191+
h
192+
});
193+
}
136194

137195
Ok(Http {
138196
id: Default::default(),
139197
url: url.parse()?,
140-
basic_auth,
198+
headers,
141199
write_sender,
142200
})
143201
}
@@ -163,10 +221,9 @@ impl Http {
163221
if len < MAX_SINGLE_CHUNK {
164222
req.headers_mut().insert(hyper::header::CONTENT_LENGTH, len.into());
165223
}
166-
// Send basic auth header
167-
if let Some(ref basic_auth) = self.basic_auth {
168-
req.headers_mut()
169-
.insert(hyper::header::AUTHORIZATION, basic_auth.clone());
224+
// Add headers
225+
if let Some(ref headers) = self.headers {
226+
req.headers_mut().extend(headers.clone())
170227
}
171228
let (tx, rx) = futures::oneshot();
172229
let result = self
@@ -229,15 +286,15 @@ fn single_response<T: Deref<Target = [u8]>>(response: T) -> Result<rpc::Value> {
229286
/// Parse bytes RPC batch response into `Result`.
230287
fn batch_response<T: Deref<Target = [u8]>>(response: T) -> Result<Vec<Result<rpc::Value>>> {
231288
// See comment in `single_response`.
232-
let mut json: Vec<serde_json::Value> =
289+
let mut json: Vec<serde_json::Value> =
233290
serde_json::from_slice(&*response).map_err(|e| Error::InvalidResponse(format!("{:?}", e)))?;
234291
for value in &mut json {
235292
if let Some(id) = value.get_mut("id") {
236293
id.take();
237294
}
238295
}
239296
let response = serde_json::from_value(serde_json::Value::Array(json))
240-
.map_err(|e| Error::InvalidResponse(format!("{:?}", e)))?;
297+
.map_err(|e| Error::InvalidResponse(format!("{:?}", e)))?;
241298
match response {
242299
rpc::Response::Batch(outputs) => Ok(outputs.into_iter().map(helpers::to_result_from_output).collect()),
243300
_ => Err(Error::InvalidResponse("Expected batch, got single.".into())),
@@ -251,47 +308,87 @@ mod tests {
251308
#[test]
252309
fn http_supports_basic_auth_with_user_and_password() {
253310
let http = Http::new("https://user:[email protected]:8545");
311+
312+
let mut expected_headers = HeaderMap::new();
313+
expected_headers.insert(
314+
hyper::header::AUTHORIZATION,
315+
HeaderValue::from_static("Basic dXNlcjpwYXNzd29yZA=="),
316+
);
317+
254318
assert!(http.is_ok());
319+
255320
match http {
256-
Ok((_, transport)) => {
257-
assert!(transport.basic_auth.is_some());
258-
assert_eq!(
259-
transport.basic_auth,
260-
Some(HeaderValue::from_static("Basic dXNlcjpwYXNzd29yZA=="))
261-
)
262-
}
321+
Ok((_, transport)) => assert_eq!(transport.headers, Some(expected_headers)),
263322
Err(_) => assert!(false, ""),
264323
}
265324
}
266325

267326
#[test]
268327
fn http_supports_basic_auth_with_user_no_password() {
269328
let http = Http::new("https://username:@127.0.0.1:8545");
329+
330+
let mut expected_headers = HeaderMap::new();
331+
expected_headers.insert(
332+
hyper::header::AUTHORIZATION,
333+
HeaderValue::from_static("Basic dXNlcm5hbWU6"),
334+
);
335+
270336
assert!(http.is_ok());
337+
271338
match http {
272-
Ok((_, transport)) => {
273-
assert!(transport.basic_auth.is_some());
274-
assert_eq!(
275-
transport.basic_auth,
276-
Some(HeaderValue::from_static("Basic dXNlcm5hbWU6"))
277-
)
278-
}
339+
Ok((_, transport)) => assert_eq!(transport.headers, Some(expected_headers)),
279340
Err(_) => assert!(false, ""),
280341
}
281342
}
282343

283344
#[test]
284345
fn http_supports_basic_auth_with_only_password() {
285346
let http = Http::new("https://:[email protected]:8545");
347+
348+
let mut expected_headers = HeaderMap::new();
349+
expected_headers.insert(
350+
hyper::header::AUTHORIZATION,
351+
HeaderValue::from_static("Basic OnBhc3N3b3Jk"),
352+
);
353+
354+
assert!(http.is_ok());
355+
match http {
356+
Ok((_, transport)) => assert_eq!(transport.headers, Some(expected_headers)),
357+
Err(_) => assert!(false, ""),
358+
}
359+
}
360+
361+
#[test]
362+
fn http_supports_custom_headers() {
363+
let mut expected_headers = HeaderMap::new();
364+
expected_headers.insert(
365+
hyper::header::CONTENT_TYPE,
366+
HeaderValue::from_static("application/json"),
367+
);
368+
369+
let http = Http::with_headers("https://127.0.0.1:8545", expected_headers.clone());
370+
371+
assert!(http.is_ok());
372+
match http {
373+
Ok((_, transport)) => assert_eq!(transport.headers, Some(expected_headers)),
374+
Err(_) => assert!(false, ""),
375+
}
376+
}
377+
378+
#[test]
379+
fn http_basic_auth_does_not_override_authorization_header() {
380+
let mut expected_headers = HeaderMap::new();
381+
expected_headers.insert(hyper::header::AUTHORIZATION, HeaderValue::from_static("Bearer foo"));
382+
expected_headers.insert(
383+
hyper::header::CONTENT_TYPE,
384+
HeaderValue::from_static("application/json"),
385+
);
386+
387+
let http = Http::with_headers("https://username:[email protected]:8545", expected_headers.clone());
388+
286389
assert!(http.is_ok());
287390
match http {
288-
Ok((_, transport)) => {
289-
assert!(transport.basic_auth.is_some());
290-
assert_eq!(
291-
transport.basic_auth,
292-
Some(HeaderValue::from_static("Basic OnBhc3N3b3Jk"))
293-
)
294-
}
391+
Ok((_, transport)) => assert_eq!(transport.headers, Some(expected_headers)),
295392
Err(_) => assert!(false, ""),
296393
}
297394
}

0 commit comments

Comments
 (0)