Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 520aaca

Browse files
authored
Merge pull request sfackler#339 from sfackler/keepalive
Support TCP keepalive
2 parents 01e8206 + d0c111d commit 520aaca

File tree

7 files changed

+163
-189
lines changed

7 files changed

+163
-189
lines changed

postgres-shared/src/params/mod.rs

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
//! Connection parameters
22
use std::error::Error;
3-
use std::path::PathBuf;
43
use std::mem;
4+
use std::path::PathBuf;
55
use std::time::Duration;
66

77
use params::url::Url;
@@ -45,6 +45,7 @@ pub struct ConnectParams {
4545
database: Option<String>,
4646
options: Vec<(String, String)>,
4747
connect_timeout: Option<Duration>,
48+
keepalive: Option<Duration>,
4849
}
4950

5051
impl ConnectParams {
@@ -86,6 +87,13 @@ impl ConnectParams {
8687
pub fn connect_timeout(&self) -> Option<Duration> {
8788
self.connect_timeout
8889
}
90+
91+
/// The interval at which TCP keepalive messages are sent on the socket.
92+
///
93+
/// This is ignored for Unix sockets.
94+
pub fn keepalive(&self) -> Option<Duration> {
95+
self.keepalive
96+
}
8997
}
9098

9199
/// A builder for `ConnectParams`.
@@ -95,6 +103,7 @@ pub struct Builder {
95103
database: Option<String>,
96104
options: Vec<(String, String)>,
97105
connect_timeout: Option<Duration>,
106+
keepalive: Option<Duration>,
98107
}
99108

100109
impl Builder {
@@ -106,6 +115,7 @@ impl Builder {
106115
database: None,
107116
options: vec![],
108117
connect_timeout: None,
118+
keepalive: None,
109119
}
110120
}
111121

@@ -142,6 +152,12 @@ impl Builder {
142152
self
143153
}
144154

155+
/// Sets the keepalive interval.
156+
pub fn keepalive(&mut self, keepalive: Option<Duration>) -> &mut Builder {
157+
self.keepalive = keepalive;
158+
self
159+
}
160+
145161
/// Constructs a `ConnectParams` from the builder.
146162
pub fn build(&mut self, host: Host) -> ConnectParams {
147163
ConnectParams {
@@ -151,6 +167,7 @@ impl Builder {
151167
database: self.database.take(),
152168
options: mem::replace(&mut self.options, vec![]),
153169
connect_timeout: self.connect_timeout,
170+
keepalive: self.keepalive,
154171
}
155172
}
156173
}
@@ -188,11 +205,12 @@ impl IntoConnectParams for Url {
188205
host,
189206
port,
190207
user,
191-
path: url::Path {
192-
path,
193-
query: options,
194-
..
195-
},
208+
path:
209+
url::Path {
210+
path,
211+
query: options,
212+
..
213+
},
196214
..
197215
} = self;
198216

@@ -218,6 +236,11 @@ impl IntoConnectParams for Url {
218236
let timeout = Duration::from_secs(timeout);
219237
builder.connect_timeout(Some(timeout));
220238
}
239+
"keepalive" => {
240+
let keepalive = value.parse().map_err(|_| "invalid keepalive")?;
241+
let keepalive = Duration::from_secs(keepalive);
242+
builder.keepalive(Some(keepalive));
243+
}
221244
_ => {
222245
builder.option(&name, &value);
223246
}

postgres/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ no-logging = []
5858
bytes = "0.4"
5959
fallible-iterator = "0.1.3"
6060
log = "0.4"
61-
socket2 = "0.3"
61+
socket2 = { version = "0.3.5", features = ["unix"] }
6262

6363
openssl = { version = "0.9.23", optional = true }
6464
native-tls = { version = "0.1", optional = true }

postgres/src/priv_io.rs

Lines changed: 38 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,20 @@
1-
use std::io::{self, BufWriter, Read, Write};
2-
use std::net::{ToSocketAddrs, SocketAddr};
3-
use std::time::Duration;
4-
use std::result;
51
use bytes::{BufMut, BytesMut};
2+
use postgres_protocol::message::backend;
3+
use postgres_protocol::message::frontend;
4+
use socket2::{Domain, SockAddr, Socket, Type};
5+
use std::io::{self, BufWriter, Read, Write};
6+
use std::net::{SocketAddr, ToSocketAddrs};
67
#[cfg(unix)]
7-
use std::os::unix::net::UnixStream;
8-
#[cfg(unix)]
9-
use std::os::unix::io::{AsRawFd, RawFd, FromRawFd, IntoRawFd};
8+
use std::os::unix::io::{AsRawFd, RawFd};
109
#[cfg(windows)]
1110
use std::os::windows::io::{AsRawSocket, RawSocket};
12-
use postgres_protocol::message::frontend;
13-
use postgres_protocol::message::backend;
14-
use socket2::{Socket, SockAddr, Domain, Type};
11+
use std::result;
12+
use std::time::Duration;
1513

16-
use {Result, TlsMode};
1714
use error;
18-
use tls::TlsStream;
1915
use params::{ConnectParams, Host};
16+
use tls::TlsStream;
17+
use {Result, TlsMode};
2018

2119
const INITIAL_CAPACITY: usize = 8 * 1024;
2220

@@ -61,9 +59,10 @@ impl MessageStream {
6159

6260
fn read_in(&mut self) -> io::Result<()> {
6361
self.in_buf.reserve(1);
64-
match self.stream.get_mut().read(
65-
unsafe { self.in_buf.bytes_mut() },
66-
) {
62+
match self.stream
63+
.get_mut()
64+
.read(unsafe { self.in_buf.bytes_mut() })
65+
{
6766
Ok(0) => Err(io::Error::new(
6867
io::ErrorKind::UnexpectedEof,
6968
"unexpected EOF",
@@ -88,8 +87,11 @@ impl MessageStream {
8887
match r {
8988
Ok(()) => {}
9089
Err(ref e)
91-
if e.kind() == io::ErrorKind::WouldBlock ||
92-
e.kind() == io::ErrorKind::TimedOut => return Ok(None),
90+
if e.kind() == io::ErrorKind::WouldBlock
91+
|| e.kind() == io::ErrorKind::TimedOut =>
92+
{
93+
return Ok(None)
94+
}
9395
Err(e) => return Err(e),
9496
}
9597
}
@@ -184,6 +186,9 @@ fn open_socket(params: &ConnectParams) -> Result<Socket> {
184186
SocketAddr::V6(_) => Domain::ipv6(),
185187
};
186188
let socket = Socket::new(domain, Type::stream(), None)?;
189+
if let Some(keepalive) = params.keepalive() {
190+
socket.set_keepalive(Some(keepalive))?;
191+
}
187192
let addr = SockAddr::from(addr);
188193
let r = match params.connect_timeout() {
189194
Some(timeout) => socket.connect_timeout(&addr, timeout),
@@ -195,33 +200,28 @@ fn open_socket(params: &ConnectParams) -> Result<Socket> {
195200
}
196201
}
197202

198-
Err(
199-
error
200-
.unwrap_or_else(|| {
201-
io::Error::new(
202-
io::ErrorKind::InvalidInput,
203-
"could not resolve any addresses",
204-
)
205-
})
206-
.into(),
207-
)
203+
Err(error
204+
.unwrap_or_else(|| {
205+
io::Error::new(
206+
io::ErrorKind::InvalidInput,
207+
"could not resolve any addresses",
208+
)
209+
})
210+
.into())
208211
}
209212
#[cfg(unix)]
210213
Host::Unix(ref path) => {
211214
let path = path.join(&format!(".s.PGSQL.{}", port));
212-
Ok(UnixStream::connect(&path).map(|s| unsafe {
213-
Socket::from_raw_fd(s.into_raw_fd())
214-
})?)
215+
let socket = Socket::new(Domain::unix(), Type::stream(), None)?;
216+
let addr = SockAddr::unix(path)?;
217+
socket.connect(&addr)?;
218+
Ok(socket)
215219
}
216220
#[cfg(not(unix))]
217-
Host::Unix(..) => {
218-
Err(
219-
io::Error::new(
220-
io::ErrorKind::InvalidInput,
221-
"unix sockets are not supported on this system",
222-
).into(),
223-
)
224-
}
221+
Host::Unix(..) => Err(io::Error::new(
222+
io::ErrorKind::InvalidInput,
223+
"unix sockets are not supported on this system",
224+
).into()),
225225
}
226226
}
227227

0 commit comments

Comments
 (0)