Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 369f6e0

Browse files
committed
tokio-postgres-openssl
1 parent 5c89b35 commit 369f6e0

File tree

4 files changed

+203
-0
lines changed

4 files changed

+203
-0
lines changed

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ members = [
77
"postgres-openssl",
88
"postgres-native-tls",
99
"tokio-postgres",
10+
"tokio-postgres-openssl",
1011
]
1112

1213
[patch.crates-io]

tokio-postgres-openssl/Cargo.toml

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
[package]
2+
name = "tokio-postgres-openssl"
3+
version = "0.1.0"
4+
authors = ["Steven Fackler <[email protected]>"]
5+
6+
[dependencies]
7+
bytes = "0.4"
8+
futures = "0.1"
9+
openssl = "0.10"
10+
tokio-io = "0.1"
11+
tokio-openssl = "0.2"
12+
tokio-postgres = { version = "0.3", path = "../tokio-postgres" }
13+
14+
[dev-dependencies]
15+
tokio = "0.1.7"

tokio-postgres-openssl/src/lib.rs

+127
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
extern crate bytes;
2+
extern crate futures;
3+
extern crate openssl;
4+
extern crate tokio_io;
5+
extern crate tokio_openssl;
6+
extern crate tokio_postgres;
7+
8+
#[cfg(test)]
9+
extern crate tokio;
10+
11+
use bytes::{Buf, BufMut};
12+
use futures::{Future, IntoFuture, Poll};
13+
use openssl::error::ErrorStack;
14+
use openssl::ssl::{ConnectConfiguration, SslConnector, SslMethod};
15+
use std::error::Error;
16+
use std::io::{self, Read, Write};
17+
use tokio_io::{AsyncRead, AsyncWrite};
18+
use tokio_openssl::ConnectConfigurationExt;
19+
use tokio_postgres::tls::{Socket, TlsConnect, TlsStream};
20+
21+
#[cfg(test)]
22+
mod test;
23+
24+
pub struct TlsConnector {
25+
connector: SslConnector,
26+
callback: Box<Fn(&mut ConnectConfiguration) -> Result<(), ErrorStack> + Sync + Send>,
27+
}
28+
29+
impl TlsConnector {
30+
pub fn new() -> Result<TlsConnector, ErrorStack> {
31+
let connector = SslConnector::builder(SslMethod::tls())?.build();
32+
Ok(TlsConnector::with_connector(connector))
33+
}
34+
35+
pub fn with_connector(connector: SslConnector) -> TlsConnector {
36+
TlsConnector {
37+
connector,
38+
callback: Box::new(|_| Ok(())),
39+
}
40+
}
41+
42+
pub fn set_callback<F>(&mut self, f: F)
43+
where
44+
F: Fn(&mut ConnectConfiguration) -> Result<(), ErrorStack> + 'static + Sync + Send,
45+
{
46+
self.callback = Box::new(f);
47+
}
48+
}
49+
50+
impl TlsConnect for TlsConnector {
51+
fn connect(
52+
&self,
53+
domain: &str,
54+
socket: Socket,
55+
) -> Box<Future<Item = Box<TlsStream>, Error = Box<Error + Sync + Send>> + Sync + Send> {
56+
let f = self
57+
.connector
58+
.configure()
59+
.and_then(|mut ssl| (self.callback)(&mut ssl).map(|_| ssl))
60+
.map_err(|e| {
61+
let e: Box<Error + Sync + Send> = Box::new(e);
62+
e
63+
})
64+
.into_future()
65+
.and_then({
66+
let domain = domain.to_string();
67+
move |ssl| {
68+
ssl.connect_async(&domain, socket)
69+
.map(|s| {
70+
let s: Box<TlsStream> = Box::new(SslStream(s));
71+
s
72+
})
73+
.map_err(|e| {
74+
let e: Box<Error + Sync + Send> = Box::new(e);
75+
e
76+
})
77+
}
78+
});
79+
Box::new(f)
80+
}
81+
}
82+
83+
struct SslStream(tokio_openssl::SslStream<Socket>);
84+
85+
impl Read for SslStream {
86+
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
87+
self.0.read(buf)
88+
}
89+
}
90+
91+
impl AsyncRead for SslStream {
92+
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
93+
self.0.prepare_uninitialized_buffer(buf)
94+
}
95+
96+
fn read_buf<B>(&mut self, buf: &mut B) -> Poll<usize, io::Error>
97+
where
98+
B: BufMut,
99+
{
100+
self.0.read_buf(buf)
101+
}
102+
}
103+
104+
impl Write for SslStream {
105+
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
106+
self.0.write(buf)
107+
}
108+
109+
fn flush(&mut self) -> io::Result<()> {
110+
self.0.flush()
111+
}
112+
}
113+
114+
impl AsyncWrite for SslStream {
115+
fn shutdown(&mut self) -> Poll<(), io::Error> {
116+
self.0.shutdown()
117+
}
118+
119+
fn write_buf<B>(&mut self, buf: &mut B) -> Poll<usize, io::Error>
120+
where
121+
B: Buf,
122+
{
123+
self.0.write_buf(buf)
124+
}
125+
}
126+
127+
impl TlsStream for SslStream {}

tokio-postgres-openssl/src/test.rs

+60
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
use futures::{Future, Stream};
2+
use openssl::ssl::{SslConnector, SslMethod};
3+
use tokio::runtime::current_thread::Runtime;
4+
use tokio_postgres::{self, TlsMode};
5+
6+
use TlsConnector;
7+
8+
fn smoke_test(url: &str, tls: TlsMode) {
9+
let mut runtime = Runtime::new().unwrap();
10+
11+
let handshake = tokio_postgres::connect(url.parse().unwrap(), tls);
12+
let (mut client, connection) = runtime.block_on(handshake).unwrap();
13+
let connection = connection.map_err(|e| panic!("{}", e));
14+
runtime.handle().spawn(connection).unwrap();
15+
16+
let prepare = client.prepare("SELECT 1::INT4");
17+
let statement = runtime.block_on(prepare).unwrap();
18+
let select = client.query(&statement, &[]).collect().map(|rows| {
19+
assert_eq!(rows.len(), 1);
20+
assert_eq!(rows[0].get::<_, i32>(0), 1);
21+
});
22+
runtime.block_on(select).unwrap();
23+
24+
drop(statement);
25+
drop(client);
26+
runtime.run().unwrap();
27+
}
28+
29+
#[test]
30+
fn require() {
31+
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
32+
builder.set_ca_file("../test/server.crt").unwrap();
33+
let connector = TlsConnector::with_connector(builder.build());
34+
smoke_test(
35+
"postgres://ssl_user@localhost:5433/postgres",
36+
TlsMode::Require(Box::new(connector)),
37+
);
38+
}
39+
40+
#[test]
41+
fn prefer() {
42+
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
43+
builder.set_ca_file("../test/server.crt").unwrap();
44+
let connector = TlsConnector::with_connector(builder.build());
45+
smoke_test(
46+
"postgres://ssl_user@localhost:5433/postgres",
47+
TlsMode::Prefer(Box::new(connector)),
48+
);
49+
}
50+
51+
#[test]
52+
fn scram_user() {
53+
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
54+
builder.set_ca_file("../test/server.crt").unwrap();
55+
let connector = TlsConnector::with_connector(builder.build());
56+
smoke_test(
57+
"postgres://scram_user:password@localhost:5433/postgres",
58+
TlsMode::Require(Box::new(connector)),
59+
);
60+
}

0 commit comments

Comments
 (0)