Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 42 additions & 5 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,10 @@ pub struct ClientQuicConfig {
pub sni_hostname: NoneOrOne<String>,
#[serde(alias = "alpn_protocol", default)]
pub alpn_protocols: NoneOrSome<String>,
#[serde(default)]
pub key: Option<String>,
#[serde(default)]
pub cert: Option<String>,
}

impl Default for ClientQuicConfig {
Expand All @@ -367,6 +371,8 @@ impl Default for ClientQuicConfig {
verify: true,
sni_hostname: NoneOrOne::Unspecified,
alpn_protocols: NoneOrSome::Unspecified,
key: None,
cert: None,
}
}
}
Expand Down Expand Up @@ -420,6 +426,10 @@ pub struct TlsClientConfig {
pub sni_hostname: NoneOrOne<String>,
#[serde(alias = "alpn_protocol", default)]
pub alpn_protocols: NoneOrSome<String>,
#[serde(default)]
pub key: Option<String>,
#[serde(default)]
pub cert: Option<String>,
pub protocol: Box<ClientProxyConfig>,
}

Expand Down Expand Up @@ -680,11 +690,21 @@ fn validate_client_config(client_config: &mut ClientConfig) -> std::io::Result<(
));
}

if client_config.transport != Transport::Quic && client_config.quic_settings.is_some() {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"QUIC transport is not selected but QUIC settings specified",
));
if let Some(ref quic_config) = client_config.quic_settings {
if client_config.transport != Transport::Quic {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"QUIC transport is not selected but QUIC settings specified",
));
}

let ClientQuicConfig { cert, key, .. } = quic_config;
if cert.is_none() != key.is_none() {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"Both client cert and key have to be specified, or both have to be omitted",
));
}
}

#[cfg(not(any(target_os = "android", target_os = "fuchsia", target_os = "linux")))]
Expand All @@ -695,6 +715,23 @@ fn validate_client_config(client_config: &mut ClientConfig) -> std::io::Result<(
));
}

validate_client_proxy_config(&client_config.protocol)?;

Ok(())
}

/// Validates a client proxy protocol configuration.
///
/// For a TLS client layer, the client certificate and private key must be
/// configured together: either both `cert` and `key` are set, or neither is.
/// Downstream, the pair is combined with `key.zip(cert)`, so a lone `cert` or
/// `key` would otherwise be dropped without any diagnostic.
///
/// A TLS layer wraps an inner `protocol`, which is validated recursively so that
/// nested TLS layers receive the same check as the outermost one.
///
/// # Errors
///
/// Returns an error of kind `InvalidInput` if any TLS layer has exactly one of
/// `cert` / `key` specified.
fn validate_client_proxy_config(client_proxy_config: &ClientProxyConfig) -> std::io::Result<()> {
    if let ClientProxyConfig::Tls(TlsClientConfig { cert, key, protocol, .. }) = client_proxy_config {
        if cert.is_some() != key.is_some() {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Both client cert and key have to be specified, or both have to be omitted",
            ));
        }
        // The inner protocol may itself be another TLS layer; validate it too.
        // NOTE(review): other variants may also wrap an inner protocol — only the
        // TLS variant's shape is visible here, so only it is recursed into.
        validate_client_proxy_config(protocol)?;
    }
    Ok(())
}

Expand Down
25 changes: 19 additions & 6 deletions src/rustls_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,39 @@ pub fn create_client_config(
verify: bool,
alpn_protocols: &[String],
enable_sni: bool,
client_key_and_cert: Option<(Vec<u8>, Vec<u8>)>,
) -> rustls::ClientConfig {
let builder = rustls::ClientConfig::builder_with_provider(Arc::new(
rustls::crypto::ring::default_provider(),
))
.with_safe_default_protocol_versions()
.unwrap();

let mut config = if !verify {
let builder = if !verify {
builder
.dangerous()
.with_custom_certificate_verifier(get_disabled_verifier())
.with_no_client_auth()
} else {
let root_store = rustls::RootCertStore {
roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
};
builder
.with_root_certificates(root_store)
.with_no_client_auth()
builder.with_root_certificates(root_store)
};

let mut config = match client_key_and_cert {
Some((key_bytes, cert_bytes)) => {
let certs = vec![
rustls::pki_types::CertificateDer::from_pem_slice(&cert_bytes)
.unwrap()
.into_owned(),
];

let privkey = rustls::pki_types::PrivateKeyDer::from_pem_slice(&key_bytes).unwrap();
builder
.with_client_auth_cert(certs, privkey)
.expect("Could not parse client certificate")
}
None => builder.with_no_client_auth(),
};

config.alpn_protocols = alpn_protocols
Expand Down Expand Up @@ -101,7 +115,6 @@ pub fn create_server_config(
.into_owned(),
];

// there's no into_owned for PrivateKeyDer.
let privkey = rustls::pki_types::PrivateKeyDer::from_pem_slice(key_bytes).unwrap();

let builder = rustls::ServerConfig::builder_with_provider(Arc::new(
Expand Down
17 changes: 17 additions & 0 deletions src/tcp/tcp_client_connector.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::io::Read;
use std::sync::atomic::{AtomicU8, Ordering};
use std::sync::Arc;

Expand Down Expand Up @@ -51,6 +52,8 @@ impl TcpClientConnector {
verify,
alpn_protocols,
sni_hostname,
key,
cert,
} = client_config.quic_settings.unwrap_or_default();

let sni_hostname = if sni_hostname.is_unspecified() {
Expand All @@ -72,10 +75,24 @@ impl TcpClientConnector {
}
};

let key_and_cert_bytes = key.zip(cert).map(|(key, cert)| {
// TODO: do this asynchronously
let mut cert_file = std::fs::File::open(&cert).unwrap();
let mut cert_bytes = vec![];
cert_file.read_to_end(&mut cert_bytes).unwrap();

let mut key_file = std::fs::File::open(&key).unwrap();
let mut key_bytes = vec![];
key_file.read_to_end(&mut key_bytes).unwrap();

(key_bytes, cert_bytes)
});

let rustls_client_config = create_client_config(
verify,
&alpn_protocols.into_vec(),
sni_hostname.is_some(),
key_and_cert_bytes,
);

let quic_client_config = quinn::crypto::rustls::QuicClientConfig::with_initial(
Expand Down
16 changes: 16 additions & 0 deletions src/tcp/tcp_handler_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,8 @@ pub fn create_tcp_client_handler(
sni_hostname,
alpn_protocols,
protocol,
key,
cert,
} = tls_client_config;

let sni_hostname = if sni_hostname.is_unspecified() {
Expand All @@ -269,10 +271,24 @@ pub fn create_tcp_client_handler(
sni_hostname.into_option()
};

let key_and_cert_bytes = key.zip(cert).map(|(key, cert)| {
// TODO: do this asynchronously
let mut cert_file = std::fs::File::open(&cert).unwrap();
let mut cert_bytes = vec![];
cert_file.read_to_end(&mut cert_bytes).unwrap();

let mut key_file = std::fs::File::open(&key).unwrap();
let mut key_bytes = vec![];
key_file.read_to_end(&mut key_bytes).unwrap();

(key_bytes, cert_bytes)
});

let client_config = Arc::new(create_client_config(
verify,
&alpn_protocols.into_vec(),
sni_hostname.is_some(),
key_and_cert_bytes,
));

let server_name = match sni_hostname {
Expand Down