diff --git a/Cargo.toml b/Cargo.toml index 8d65061..983aa42 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "generic-async-http-client" -version = "0.4.0" +version = "0.5.0" authors = ["User65k <15049544+User65k@users.noreply.github.com>"] edition = "2021" @@ -15,8 +15,8 @@ async-std = {version="1.9.0",optional=true} async-h1 = {version="2.3",optional=true} http-types = {version="2.11",optional=true} -hyper = { version = "0.14", optional=true } -serde_qs = { version ="0.12.0", optional=true } +hyper = { version = "1.2", optional=true } +serde_qs = { version ="0.12", optional=true } serde_urlencoded = { version ="0.7", optional=true } serde_json = {version="1.0",optional=true} tokio = {version = "1.6", optional=true} @@ -32,22 +32,23 @@ log = "0.4" serde = "1.0" #pin-project = "1.0" -futures-rustls = {version="0.24.0",optional=true} -tokio-rustls = { version = "0.24.1", optional = true } -webpki-roots = {version="0.25.2",optional=true} +futures-rustls = {version="0.25.0",optional=true} +tokio-rustls = { version = "0.25.0", optional = true } +webpki-roots = {version="0.26.0",optional=true} #rustls-native-certs -async-native-tls = { version = "0.5.0", default-features = false, optional = true } +async-native-tls = { version = "0.5", default-features = false, optional = true } -cookie_store = { version = "0.20.0", optional = true } +cookie_store = { version = "0.21.0", optional = true } +async-trait = { version = "0.1", optional = true } [features] -use_hyper = ["tokio/net", "hyper/http1", "hyper/http2", "hyper/client", "hyper/runtime", "serde_qs", "serde_urlencoded","serde_json"] +use_hyper = ["tokio/net", "tokio/rt", "hyper/http1", "hyper/http2", "hyper/client", "serde_qs", "serde_urlencoded","serde_json"] use_async_h1 = ["async-std", "async-h1", "http-types"] use_web_sys = ["web-sys", "wasm-bindgen", "wasm-bindgen-futures", "js-sys"] cookies = ["cookie_store"] -proxies = [] +proxies = ["async-trait"] rustls = ["futures-rustls", "tokio-rustls", "webpki-roots"] async_native_tls = ["use_async_h1","async-native-tls/runtime-async-std"] diff --git a/src/hyper/connector.rs b/src/hyper/connector.rs index 3e6f728..1111ac3 100644 --- a/src/hyper/connector.rs +++ b/src/hyper/connector.rs @@ -1,68 +1,82 @@ use hyper::{ - http::uri::{Scheme, Uri}, - service::Service, + header::{HeaderValue, HOST}, http::uri::{Scheme, Uri} }; -use std::{ - future::Future, - pin::Pin, - task::{Context, Poll}, -}; - -use std::io; - use crate::tcp::Stream; -#[derive(Clone)] -pub struct Connector; - -impl Connector { - pub fn new() -> Connector { - Connector {} - } +async fn connect_to_uri(dst: &Uri) -> Result { + let tls = match dst.scheme_str() { + Some("https") => true, + Some("http") => false, + _ => { + return Err(super::Error::Scheme) + } + }; + let host = match dst.host() { + Some(s) => s, + None => { + return Err(hyper::http::uri::Authority::try_from("]").unwrap_err().into()); + } + }; + let port = match dst.port() { + Some(port) => port.as_u16(), + None => { + if dst.scheme() == Some(&Scheme::HTTPS) { + 443 + } else { + 80 + } + } + }; + Stream::connect(host, port, tls).await.map_err(|e|e.into()) } -impl Service for Connector { - type Response = Stream; - type Error = std::io::Error; - // We can't "name" an `async` generated future. - type Future = Pin> + Send>>; +#[derive(Debug, Clone, Default)] +pub enum HyperClient { + #[default] + New,/* + H1(hyper::client::conn::http1::SendRequest), + TlsH1(), + TlsH2(),*/ +} - fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { - // This connector is always ready, but others might not be. - Poll::Ready(Ok(())) - } +fn origin_form(uri: &mut Uri) { + let path = match uri.path_and_query() { + Some(path) if path.as_str() != "/" => { + let mut parts = hyper::http::uri::Parts::default(); + parts.path_and_query = Some(path.clone()); + Uri::from_parts(parts).expect("path is valid uri") + } + _none_or_just_slash => { + debug_assert!(Uri::default() == "/"); + Uri::default() + } + }; + *uri = path +} - fn call(&mut self, dst: Uri) -> Self::Future { - let fut = async move { - let tls = match dst.scheme_str() { - Some("https") => true, - Some("http") => false, - _ => { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "scheme must be http or https", - )) - } - }; - let host = match dst.host() { - Some(s) => s, - None => { - return Err(io::Error::new(io::ErrorKind::InvalidInput, "missing host")); - } - }; - let port = match dst.port() { - Some(port) => port.as_u16(), - None => { - if dst.scheme() == Some(&Scheme::HTTPS) { - 443 - } else { - 80 - } - } - }; - Stream::connect(host, port, tls).await - }; +impl HyperClient { + pub async fn request(&mut self, mut req: super::Request) -> Result, super::Error> { + let io = connect_to_uri(req.uri()).await?; + let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?; + tokio::task::spawn(async move { + if let Err(err) = conn.await { + println!("Connection failed: {:?}", err); + } + }); + let uri = req.uri().clone(); + req.headers_mut().entry(HOST).or_insert_with(|| { + let hostname = uri.host().expect("authority implies host"); + if let Some(port) = uri.port() { + let s = format!("{}:{}", hostname, port); + HeaderValue::from_str(&s) + } else { + HeaderValue::from_str(hostname) + } + .expect("uri host is valid header value") + }); - Box::pin(fut) + origin_form(req.uri_mut()); + + sender.send_request(req).await.map_err(|e|e.into()) } } diff --git a/src/hyper/mod.rs b/src/hyper/mod.rs index 6a9a77e..807c3b9 100644 --- a/src/hyper/mod.rs +++ b/src/hyper/mod.rs @@ -1,37 +1,40 @@ -use std::{str::FromStr, convert::{TryFrom, Infallible}}; +use std::{convert::{Infallible, TryFrom}, str::FromStr}; use serde::Serialize; pub use hyper::{ header::{HeaderName, HeaderValue}, - Body, + body::Incoming, }; use hyper::{ - header::{InvalidHeaderName, InvalidHeaderValue, CONTENT_TYPE}, - http::{ + body::{Body as BodyTrait, Bytes, Frame, SizeHint}, header::{InvalidHeaderName, InvalidHeaderValue, CONTENT_TYPE}, http::{ method::{InvalidMethod, Method}, request::Builder, uri::{Builder as UriBuilder, InvalidUri, PathAndQuery, Uri}, Error as HTTPError, - }, - Client, Error as HyperError, Response, + }, Error as HyperError, Request, Response }; use std::mem::take; mod connector; -use connector::Connector; +pub(crate) use connector::HyperClient; + +pub(crate) fn get_client() -> HyperClient { + HyperClient::default() +} #[derive(Debug)] pub struct Req { req: Builder, body: Body, + pub(crate) client: Option } pub struct Resp { - resp: Response, + resp: Response, } -impl Into> for crate::Response { - fn into(self) -> Response { +impl Into> for crate::Response { + fn into(self) -> Response { self.0.resp } } @@ -52,6 +55,7 @@ where Ok(crate::Request(Req { req, body: Body::empty(), + client: None })) } } @@ -84,31 +88,32 @@ impl Req { Req { req, body: Body::empty(), + client: None } } - pub async fn send_request(self) -> Result { + pub async fn send_request(mut self) -> Result { let req = self.req.body(self.body)?; - let connector = Connector::new(); - let client = Client::builder().build::<_, Body>(connector); - - let resp = client.request(req).await?; + let resp = if let Some(mut client) = self.client.take() { + client.request(req).await? + }else{ + get_client().request(req).await? + }; Ok(Resp { resp }) } pub fn json(&mut self, json: &T) -> Result<(), Error> { - let bytes = serde_json::to_vec(&json)?; + let bytes = serde_json::to_string(&json)?; self.set_header(CONTENT_TYPE, HeaderValue::from_static("application/json"))?; - self.body = Body::from(bytes); + self.body = bytes.into(); Ok(()) } pub fn form(&mut self, data: &T) -> Result<(), Error> { let query = serde_urlencoded::to_string(data)?; - let bytes = query.into_bytes(); self.set_header( CONTENT_TYPE, HeaderValue::from_static("application/x-www-form-urlencoded"), )?; - self.body = Body::from(bytes); + self.body = query.into(); Ok(()) } pub fn query(&mut self, query: &T) -> Result<(), Error> { @@ -145,7 +150,6 @@ impl Req { } } use hyper::body::Buf; -use hyper::body::{aggregate, to_bytes}; use serde::de::DeserializeOwned; impl Resp { pub fn status(&self) -> u16 { @@ -159,8 +163,12 @@ impl Resp { Ok(serde_json::from_reader(reader)?) } pub async fn bytes(&mut self) -> Result, Error> { - let b = to_bytes(self.resp.body_mut()).await?; - Ok(b.to_vec()) + let mut b = aggregate(self.resp.body_mut()).await?; + let capacity = b.remaining(); + //TODO uninit + let mut v = vec![0;capacity]; + b.copy_to_slice(&mut v); + Ok(v) } pub async fn string(&mut self) -> Result { let b = self.bytes().await?; @@ -174,6 +182,53 @@ impl Resp { } } +struct FracturedBuf(std::collections::VecDeque); +impl Buf for FracturedBuf { + fn remaining(&self) -> usize { + self.0.iter().map(|buf| buf.remaining()).sum() + } + fn chunk(&self) -> &[u8] { + self.0.front().map(Buf::chunk).unwrap_or_default() + } + fn advance(&mut self, mut cnt: usize) { + let bufs = &mut self.0; + while cnt > 0 { + if let Some(front) = bufs.front_mut() { + let rem = front.remaining(); + if rem > cnt { + front.advance(cnt); + return; + } else { + front.advance(rem); + cnt -= rem; + } + } else { + //no data -> panic? + return; + } + bufs.pop_front(); + } + } +} +struct Framed<'a>(&'a mut Incoming); + +impl<'a> futures::Future for Framed<'a> { + type Output = Option, hyper::Error>>; + + fn poll(mut self: std::pin::Pin<&mut Self>, ctx: &mut std::task::Context<'_>) -> std::task::Poll { + std::pin::Pin::new(&mut self.0).poll_frame(ctx) + } +} +async fn aggregate(body: &mut Incoming) -> Result { + let mut v = std::collections::VecDeque::new(); + while let Some(f) = Framed(body).await { + if let Ok(d) = f?.into_data() { + v.push_back(d); + } + } + Ok(FracturedBuf(v)) +} + #[derive(Debug)] pub enum Error { Scheme, @@ -186,6 +241,7 @@ pub enum Error { InvalidHeaderName(InvalidHeaderName), InvalidUri(InvalidUri), Urlencoded(serde_urlencoded::ser::Error), + Io(std::io::Error) } impl std::error::Error for Error {} use std::fmt; @@ -193,19 +249,25 @@ impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Error::Scheme => write!(f, "Scheme"), - Error::Http(i) => write!(f, "{}", i.to_string()), - Error::InvalidQueryString(i) => write!(f, "{}", i.to_string()), - Error::InvalidMethod(i) => write!(f, "{}", i.to_string()), - Error::Hyper(i) => write!(f, "{}", i.to_string()), - Error::Json(i) => write!(f, "{}", i.to_string()), - Error::InvalidHeaderValue(i) => write!(f, "{}", i.to_string()), - Error::InvalidHeaderName(i) => write!(f, "{}", i.to_string()), - Error::InvalidUri(i) => write!(f, "{}", i.to_string()), - Error::Urlencoded(i) => write!(f, "{}", i.to_string()), + Error::Http(i) => write!(f, "{}", i), + Error::InvalidQueryString(i) => write!(f, "{}", i), + Error::InvalidMethod(i) => write!(f, "{}", i), + Error::Hyper(i) => write!(f, "{}", i), + Error::Json(i) => write!(f, "{}", i), + Error::InvalidHeaderValue(i) => write!(f, "{}", i), + Error::InvalidHeaderName(i) => write!(f, "{}", i), + Error::InvalidUri(i) => write!(f, "{}", i), + Error::Urlencoded(i) => write!(f, "{}", i), + Error::Io(i) => write!(f, "{}", i), } } } +impl From for Error { + fn from(e: std::io::Error) -> Self { + Self::Io(e) + } +} impl From for Error { fn from(e: serde_urlencoded::ser::Error) -> Self { Self::Urlencoded(e) @@ -257,3 +319,54 @@ impl From for Error { unreachable!(); } } + +#[derive(Debug)] +pub struct Body(Vec); +impl Body { + fn empty() -> Self { + Self(vec![]) + } +} +impl From for Body { + #[inline] + fn from(t: String) -> Self { + Body(t.into_bytes()) + } +} +impl From> for Body { + #[inline] + fn from(t: Vec) -> Self { + Body(t) + } +} +impl From<&'static [u8]> for Body { + #[inline] + fn from(t: &'static [u8]) -> Self { + Body(t.to_vec()) + } +} +impl From<&'static str> for Body { + #[inline] + fn from(t: &'static str) -> Self { + Body(t.as_bytes().to_vec()) + } +} +impl hyper::body::Body for Body { + type Data = Bytes; + type Error = Infallible; + + fn poll_frame( + mut self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll, Self::Error>>> { + if self.0.is_empty() { + std::task::Poll::Ready(None) + }else{ + let v: Vec = std::mem::take(self.0.as_mut()); + std::task::Poll::Ready(Some(Ok(Frame::data(v.into())))) + } + } + fn size_hint(&self) -> SizeHint { + SizeHint::with_exact(self.0.len() as u64) + } +} \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 543cad1..d9089b6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,6 +22,8 @@ mod imp; #[cfg(any(feature = "use_hyper", feature = "use_async_h1"))] mod tcp; +#[cfg(all(any(feature = "use_hyper", feature = "use_async_h1"), feature = "proxies"))] +pub use tcp::proxy; #[cfg(feature = "use_async_h1")] #[path = "a_h1/mod.rs"] @@ -89,13 +91,8 @@ mod tests { pub(crate) async fn assert_stream(stream: &mut TcpStream, should_be: &[u8]) -> std::io::Result<()> { let l = should_be.len(); let mut req: Vec = vec![0; l]; - stream.read_exact(req.as_mut_slice()).await?; - if req != should_be { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "req not as expected", - )); - } + let _r = stream.read(req.as_mut_slice()).await?; + assert_eq!(req, should_be); Ok(()) } pub(crate) async fn listen_somewhere() -> Result<(TcpListener, u16, String), std::io::Error> { @@ -110,14 +107,6 @@ mod tests { async fn server(listener: TcpListener, host: String, port: u16) -> std::io::Result { let (mut stream, _) = listener.accept().await?; let mut output = Vec::with_capacity(1); - - #[cfg(feature = "use_hyper")] - assert_stream( - &mut stream, - format!("GET / HTTP/1.1\r\nhost: {}:{}\r\n\r\n",host,port).as_bytes(), - ) - .await?; - #[cfg(feature = "use_async_h1")] assert_stream( &mut stream, format!("GET / HTTP/1.1\r\nhost: {}:{}\r\ncontent-length: 0\r\n\r\n",host,port).as_bytes(), @@ -127,7 +116,7 @@ mod tests { stream .write_all(b"HTTP/1.1 200 OK\r\ncontent-length: 3\r\n\r\nabc") .await?; - stream.read(&mut output).await?; + let _ = stream.read(&mut output).await?; Ok(true) } block_on(async { @@ -155,7 +144,7 @@ mod tests { #[cfg(feature = "use_hyper")] assert_stream( &mut stream, - format!("PUT / HTTP/1.1\r\ncookies: jo\r\nhost: {}:{}\r\n\r\n",host,port).as_bytes(), + format!("PUT / HTTP/1.1\r\ncookies: jo\r\nhost: {}:{}\r\ncontent-length: 0\r\n\r\n",host,port).as_bytes(), ) .await?; diff --git a/src/tcp/http.rs b/src/tcp/http.rs index 5e487dd..ef5bf45 100644 --- a/src/tcp/http.rs +++ b/src/tcp/http.rs @@ -27,7 +27,7 @@ pub async fn connect_via_http_prx( "" //TODO Auth ) .into_bytes(); - socket.write(&buf).await?; + socket.write_all(&buf).await?; let mut buffer = [0; 40]; let r = socket.read(&mut buffer).await?; diff --git a/src/tcp/mod.rs b/src/tcp/mod.rs index e01b08c..03c21ca 100644 --- a/src/tcp/mod.rs +++ b/src/tcp/mod.rs @@ -20,28 +20,28 @@ use async_std::{ }; #[cfg(all(feature = "use_async_h1", feature = "proxies"))] use http_types::Url as Uri; -#[cfg(feature = "use_hyper")] -use hyper::client::connect::Connection; #[cfg(all(feature = "use_hyper", feature = "proxies"))] use hyper::http::uri::Uri; #[cfg(feature = "use_hyper")] use tokio::{ - io::{AsyncRead, AsyncWrite as Write, ReadBuf}, + io::{AsyncRead as _, AsyncWrite as _}, net::TcpStream, }; +#[cfg(feature = "use_hyper")] +use hyper::rt::{Read, Write, ReadBufCursor}; #[cfg(any(feature = "async_native_tls",feature = "hyper_native_tls"))] use async_native_tls::{TlsConnector, TlsStream}; #[cfg(all(feature = "rustls", feature = "use_async_h1"))] use futures_rustls::{ client::TlsStream, - rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore, ServerName}, + rustls::{ClientConfig, RootCertStore, pki_types::ServerName}, TlsConnector, }; #[cfg(all(feature = "rustls", feature = "use_hyper"))] use tokio_rustls::{ client::TlsStream, - rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore, ServerName}, + rustls::{ClientConfig, RootCertStore, pki_types::ServerName}, TlsConnector, }; #[cfg(feature = "rustls")] @@ -60,106 +60,182 @@ enum State { Plain(TcpStream), } -//static connect : Box dyn Future>> = Box::new(connect_w_proxy); +#[cfg(feature = "proxies")] +pub mod proxy { + use async_trait::async_trait; + use super::*; -/* - http_proxy, HTTPS_PROXY + /// Sets the global proxy to a `&'static Proxy`. + pub fn set_proxy(proxy: &'static dyn Proxy) { + unsafe { + GLOBAL_PROXY = proxy; + } + } + /// Sets the global proxy to a `Box`. + /// + /// This is a simple convenience wrapper over `set_proxy`, which takes a + /// `Box` rather than a `&'static Proxy`. See the documentation for + /// [`set_proxy`] for more details. + pub fn set_boxed_proxy(proxy: Box) { + set_proxy(Box::leak(proxy)) + } + /// Returns a reference to the proxy. + pub fn proxy() -> &'static dyn Proxy { + unsafe { GLOBAL_PROXY } + } + static mut GLOBAL_PROXY: &dyn Proxy = &EnvProxy; -They should be set for protocol-specific proxies. General proxy should be -set with + #[async_trait] + pub trait Proxy: Sync + Send { + async fn connect_w_proxy(&self, host: &str, port: u16, tls: bool) -> io::Result; + } - ALL_PROXY + pub struct NoProxy; + #[async_trait] + impl Proxy for NoProxy { + async fn connect_w_proxy(&self, host: &str, port: u16, _tls: bool) -> io::Result + { + TcpStream::connect((host, port)).await + } + } + /// + /// `http_proxy`, `HTTPS_PROXY` should be set for protocol-specific proxies. + /// General proxy should be set with `ALL_PROXY` + /// + /// A comma-separated list of host names that shouldn't go through any proxy is + /// set in (only an asterisk, '*' matches all hosts) `NO_PROXY` + pub struct EnvProxy; + #[async_trait] + impl Proxy for EnvProxy { + async fn connect_w_proxy(&self, host: &str, port: u16, tls: bool) -> io::Result { + let mut prx = std::env::var("ALL_PROXY") + .or_else(|_| std::env::var("all_proxy")) + .ok(); + if prx.is_none() && tls { + prx = std::env::var("HTTPS_PROXY") + .or_else(|_| std::env::var("https_proxy")) + .ok(); + } + if prx.is_none() && !tls { + prx = std::env::var("HTTP_PROXY") + .or_else(|_| std::env::var("http_proxy")) + .ok(); + } + if let Ok(no_proxy) = std::env::var("NO_PROXY").or_else(|_| std::env::var("no_proxy")) { + for h in no_proxy.split(',') { + match h.trim() { + a if a == host => {} + "*" => {} + _ => continue, + } + log::debug!("using no proxy due to env NO_PROXY"); + prx = None; + break; + } + } + match prx { + None => TcpStream::connect((host, port)).await, + Some(proxy) => { + let url = proxy + .parse::() + .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?; + #[cfg(feature = "use_hyper")] + let (phost, scheme) = (url.host(), url.scheme_str()); + #[cfg(feature = "use_async_h1")] + let (phost, scheme) = (url.host_str(), Some(url.scheme())); -A comma-separated list of host names that shouldn't go through any proxy is -set in (only an asterisk, '*' matches all hosts) + let phost = match phost { + Some(s) => s, + None => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "missing proxy host", + )); + } + }; + #[cfg(feature = "use_hyper")] + let pport = url.port().map(|p| p.as_u16()); + #[cfg(feature = "use_async_h1")] + let pport = url.port(); - NO_PROXY -*/ -#[cfg(feature = "proxies")] -pub async fn connect_w_proxy(host: &str, port: u16, tls: bool) -> io::Result { - let mut prx = std::env::var("ALL_PROXY") - .or_else(|_| std::env::var("all_proxy")) - .ok(); - if prx == None && tls { - prx = std::env::var("HTTPS_PROXY") - .or_else(|_| std::env::var("https_proxy")) - .ok(); - } - if prx == None && !tls { - prx = std::env::var("HTTP_PROXY") - .or_else(|_| std::env::var("http_proxy")) - .ok(); - } - if let Ok(no_proxy) = std::env::var("NO_PROXY").or_else(|_| std::env::var("no_proxy")) { - for h in no_proxy.split(",") { - match h.trim() { - a if a == host => {} - "*" => {} - _ => continue, + let pport = match pport { + Some(port) => port, + None => match scheme { + Some("https") => 443, + Some("http") => 80, + Some("socks5") => 1080, + Some("socks5h") => 1080, + _ => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "missing proxy port", + )) + } + }, + }; + log::info!("using proxy {}:{}", phost, pport); + match scheme { + Some("http") => connect_via_http_prx(host, port, phost, pport).await, + Some(socks5) if socks5 == "socks5" || socks5 == "socks5h" => { + connect_via_socks_prx(host, port, phost, pport, socks5 == "socks5h").await + } + _ => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "unsupported proxy scheme", + )) + } + } + } } - log::debug!("using no proxy due to env NO_PROXY"); - prx = None; - break; } } - match prx { - None => TcpStream::connect((host, port)).await, - Some(proxy) => { - let url = proxy - .parse::() - .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?; - #[cfg(feature = "use_hyper")] - let (phost, scheme) = (url.host(), url.scheme_str()); - #[cfg(feature = "use_async_h1")] - let (phost, scheme) = (url.host_str(), Some(url.scheme())); - let phost = match phost { - Some(s) => s, - None => { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "missing proxy host", - )); - } - }; - #[cfg(feature = "use_hyper")] - let pport = url.port().map(|p| p.as_u16()); - #[cfg(feature = "use_async_h1")] - let pport = url.port(); + #[cfg(test)] + mod tests { + use crate::tests::{assert_stream, TcpListener, spawn, block_on, listen_somewhere, WriteExt}; + #[test] + fn prx_from_env() { + async fn server(listener: TcpListener) -> std::io::Result { + let (mut stream, _) = listener.accept().await?; - let pport = match pport { - Some(port) => port, - None => match scheme { - Some("https") => 443, - Some("http") => 80, - Some("socks5") => 1080, - Some("socks5h") => 1080, - _ => { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "missing proxy port", - )) - } - }, - }; - log::info!("using proxy {}:{}", phost, pport); - match scheme { - Some("http") => connect_via_http_prx(host, port, phost, pport).await, - Some(socks5) if socks5 == "socks5" || socks5 == "socks5h" => { - connect_via_socks_prx(host, port, phost, pport, socks5 == "socks5h").await - } - _ => { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "unsupported proxy scheme", - )) - } + assert_stream( + &mut stream, + format!("CONNECT whatever:80 HTTP/1.1\r\nHost: whatever:80\r\n\r\n").as_bytes(), + ) + .await?; + stream.write_all(b"HTTP/1.1 200 Connected\r\n\r\n").await?; + + assert_stream( + &mut stream, + format!("GET /bla HTTP/1.1\r\nhost: whatever\r\ncontent-length: 0\r\n\r\n").as_bytes(), + ) + .await?; + stream + .write_all(b"HTTP/1.1 200 OK\r\ncontent-length: 3\r\n\r\nabc") + .await?; + + Ok(true) } + block_on(async { + let (listener, pport, phost) = listen_somewhere().await?; + std::env::set_var("HTTP_PROXY", format!("http://{phost}:{pport}/")); + std::env::set_var("NO_PROXY", &phost); + let t = spawn(server(listener)); + + let r = crate::Request::get("http://whatever/bla"); + let mut aw = r.exec().await?; + + assert_eq!(aw.status_code(), 200, "wrong status"); + assert_eq!(aw.text().await?, "abc", "wrong text"); + assert!(t.await?, "not cool"); + Ok(()) + }) + .unwrap(); } } } - #[cfg(any( feature = "rustls", feature = "hyper_native_tls", @@ -169,16 +245,9 @@ fn get_tls_connector() -> io::Result{ #[cfg(feature = "rustls")] { let mut root_store = RootCertStore::empty(); - root_store.add_trust_anchors(TLS_SERVER_ROOTS.iter().map(|ta| { - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ) - })); + root_store.extend(TLS_SERVER_ROOTS.iter().cloned()); let mut config = ClientConfig::builder() - .with_safe_defaults() .with_root_certificates(root_store) .with_no_client_auth(); @@ -186,7 +255,7 @@ fn get_tls_connector() -> io::Result{ config.alpn_protocols.push(b"h2".to_vec()); config.alpn_protocols.push(b"http/1.1".to_vec()); - return Ok(TlsConnector::from(Arc::new(config))); + Ok(TlsConnector::from(Arc::new(config))) } #[cfg(any(feature = "async_native_tls",feature = "hyper_native_tls"))] return Ok(TlsConnector::new()); @@ -195,7 +264,7 @@ fn get_tls_connector() -> io::Result{ impl Stream { pub async fn connect(host: &str, port: u16, tls: bool) -> io::Result { #[cfg(feature = "proxies")] - let tcp = connect_w_proxy(host, port, tls).await?; + let tcp = proxy::proxy().connect_w_proxy(host, port, tls).await?; #[cfg(not(feature = "proxies"))] let tcp = TcpStream::connect((host, port)).await?; log::trace!("connected to {}:{}", host, port); @@ -206,7 +275,7 @@ impl Stream { #[cfg(feature = "rustls")] let host = ServerName::try_from(host).map_err(|_e| { io::Error::new(io::ErrorKind::InvalidInput, "Invalid DNS name") - })?; + })?.to_owned(); let tlsc = get_tls_connector()?; let tls = tlsc.connect(host, tcp).await; @@ -244,22 +313,16 @@ impl Stream { } #[cfg(feature = "use_hyper")] -impl Connection for Stream { - fn connected(&self) -> hyper::client::connect::Connected { - #[cfg_attr(not(feature = "rustls"), allow(unused_mut))] - let mut c = hyper::client::connect::Connected::new(); - - match self.state { - #[cfg(feature = "rustls")] - State::Tls(ref t) => { - let (_, s) = t.get_ref(); - if Some(&b"h2"[..]) == s.alpn_protocol() { - c = c.negotiated_h2(); - } +impl Stream { + fn get_proto(&self) -> hyper::Version { + #[cfg(feature = "rustls")] + if let State::Tls(ref t) = self.state { + let (_, s) = t.get_ref(); + if Some(&b"h2"[..]) == s.alpn_protocol() { + return hyper::Version::HTTP_2; } - _ => {} } - c + hyper::Version::HTTP_11 } } @@ -325,8 +388,8 @@ impl Write for Stream { } } } -#[cfg(feature = "use_async_h1")] impl Read for Stream { + #[cfg(feature = "use_async_h1")] fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -343,23 +406,32 @@ impl Read for Stream { State::Plain(ref mut t) => Pin::new(t).poll_read(cx, buf), } } -} -#[cfg(feature = "use_hyper")] -impl AsyncRead for Stream { + #[cfg(feature = "use_hyper")] fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, + mut buf: ReadBufCursor<'_>, ) -> Poll> { let pin = self.get_mut(); - match pin.state { - #[cfg(any( - feature = "rustls", - feature = "hyper_native_tls", - feature = "async_native_tls" - ))] - State::Tls(ref mut t) => Pin::new(t).poll_read(cx, buf), - State::Plain(ref mut t) => Pin::new(t).poll_read(cx, buf), + let f = { + let mut tbuf = tokio::io::ReadBuf::uninit(unsafe { buf.as_mut() }); + let p = match pin.state { + #[cfg(any( + feature = "rustls", + feature = "hyper_native_tls", + feature = "async_native_tls" + ))] + State::Tls(ref mut t) => Pin::new(t).poll_read(cx, &mut tbuf), + State::Plain(ref mut t) => Pin::new(t).poll_read(cx, &mut tbuf), + }; + match p { + Poll::Ready(Ok(())) => tbuf.filled().len(), + o => return o, + } + }; + unsafe { + buf.advance(f); } + Poll::Ready(Ok(())) } }