Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 1788a03

Browse files
committed
notification/notice support
1 parent 997b5e0 commit 1788a03

File tree

3 files changed

+102
-19
lines changed

3 files changed

+102
-19
lines changed

tokio-postgres/src/lib.rs

+12-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ pub use postgres_shared::{error, params, types};
3434
#[doc(inline)]
3535
pub use postgres_shared::{CancelData, Notification};
3636

37-
use error::Error;
37+
use error::{DbError, Error};
3838
use params::ConnectParams;
3939
use tls::TlsConnect;
4040
use types::{FromSql, ToSql, Type};
@@ -104,6 +104,10 @@ impl Connection {
104104
pub fn parameter(&self, name: &str) -> Option<&str> {
105105
self.0.parameter(name)
106106
}
107+
108+
pub fn poll_message(&mut self) -> Poll<Option<AsyncMessage>, Error> {
109+
self.0.poll_message()
110+
}
107111
}
108112

109113
impl Future for Connection {
@@ -115,6 +119,13 @@ impl Future for Connection {
115119
}
116120
}
117121

122+
pub enum AsyncMessage {
123+
Notice(DbError),
124+
Notification(Notification),
125+
#[doc(hidden)]
126+
__NonExhaustive,
127+
}
128+
118129
#[must_use = "futures do nothing unless polled"]
119130
pub struct CancelQuery(proto::CancelFuture);
120131

tokio-postgres/src/proto/connection.rs

+31-16
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,10 @@ use std::collections::{HashMap, VecDeque};
66
use std::io;
77
use tokio_codec::Framed;
88

9-
use disconnected;
10-
use error::{self, Error};
9+
use error::{self, DbError, Error};
1110
use proto::codec::PostgresCodec;
1211
use tls::TlsStream;
13-
use {bad_response, CancelData};
12+
use {bad_response, disconnected, AsyncMessage, CancelData, Notification};
1413

1514
pub struct Request {
1615
pub messages: Vec<u8>,
@@ -71,10 +70,10 @@ impl Connection {
7170
self.stream.poll()
7271
}
7372

74-
fn poll_read(&mut self) -> Result<(), Error> {
73+
fn poll_read(&mut self) -> Result<Option<AsyncMessage>, Error> {
7574
if self.state != State::Active {
7675
trace!("poll_read: done");
77-
return Ok(());
76+
return Ok(None);
7877
}
7978

8079
loop {
@@ -85,14 +84,22 @@ impl Connection {
8584
}
8685
Async::NotReady => {
8786
trace!("poll_read: waiting on response");
88-
return Ok(());
87+
return Ok(None);
8988
}
9089
};
9190

9291
let message = match message {
93-
Message::NoticeResponse(_) | Message::NotificationResponse(_) => {
94-
// FIXME handle these
95-
continue;
92+
Message::NoticeResponse(body) => {
93+
let error = DbError::new(&mut body.fields())?;
94+
return Ok(Some(AsyncMessage::Notice(error)));
95+
}
96+
Message::NotificationResponse(body) => {
97+
let notification = Notification {
98+
process_id: body.process_id(),
99+
channel: body.channel()?.to_string(),
100+
payload: body.message()?.to_string(),
101+
};
102+
return Ok(Some(AsyncMessage::Notification(notification)));
96103
}
97104
Message::ParameterStatus(body) => {
98105
self.parameters
@@ -127,7 +134,7 @@ impl Connection {
127134
self.responses.push_front(sender);
128135
self.pending_response = Some(message);
129136
trace!("poll_read: waiting on socket");
130-
return Ok(());
137+
return Ok(None);
131138
}
132139
}
133140
}
@@ -225,18 +232,26 @@ impl Connection {
225232
Err(e) => Err(Error::from(e)),
226233
}
227234
}
235+
236+
pub fn poll_message(&mut self) -> Poll<Option<AsyncMessage>, Error> {
237+
let message = self.poll_read()?;
238+
let want_flush = self.poll_write()?;
239+
if want_flush {
240+
self.poll_flush()?;
241+
}
242+
match message {
243+
Some(message) => Ok(Async::Ready(Some(message))),
244+
None => self.poll_shutdown().map(|r| r.map(|()| None)),
245+
}
246+
}
228247
}
229248

230249
impl Future for Connection {
231250
type Item = ();
232251
type Error = Error;
233252

234253
fn poll(&mut self) -> Poll<(), Error> {
235-
self.poll_read()?;
236-
let want_flush = self.poll_write()?;
237-
if want_flush {
238-
self.poll_flush()?;
239-
}
240-
self.poll_shutdown()
254+
while let Some(_) = try_ready!(self.poll_message()) {}
255+
Ok(Async::Ready(()))
241256
}
242257
}

tokio-postgres/tests/test.rs

+59-2
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,21 @@
11
extern crate env_logger;
2-
extern crate futures;
32
extern crate tokio;
43
extern crate tokio_postgres;
54

5+
#[macro_use]
6+
extern crate futures;
7+
#[macro_use]
8+
extern crate log;
9+
10+
use futures::future;
11+
use futures::sync::mpsc;
612
use std::time::{Duration, Instant};
713
use tokio::prelude::*;
814
use tokio::runtime::current_thread::Runtime;
915
use tokio::timer::Delay;
1016
use tokio_postgres::error::SqlState;
1117
use tokio_postgres::types::{Kind, Type};
12-
use tokio_postgres::TlsMode;
18+
use tokio_postgres::{AsyncMessage, TlsMode};
1319

1420
fn smoke_test(url: &str) {
1521
let _ = env_logger::try_init();
@@ -447,3 +453,54 @@ fn custom_simple() {
447453
assert_eq!("hstore", ty.name());
448454
assert_eq!(&Kind::Simple, ty.kind());
449455
}
456+
457+
#[test]
458+
fn notifications() {
459+
let _ = env_logger::try_init();
460+
let mut runtime = Runtime::new().unwrap();
461+
462+
let handshake = tokio_postgres::connect(
463+
"postgres://postgres@localhost:5433".parse().unwrap(),
464+
TlsMode::None,
465+
);
466+
let (mut client, mut connection) = runtime.block_on(handshake).unwrap();
467+
468+
let (tx, rx) = mpsc::unbounded();
469+
let connection = future::poll_fn(move || {
470+
while let Some(message) = try_ready!(connection.poll_message().map_err(|e| panic!("{}", e)))
471+
{
472+
if let AsyncMessage::Notification(notification) = message {
473+
debug!("received {}", notification.payload);
474+
tx.unbounded_send(notification).unwrap();
475+
}
476+
}
477+
478+
Ok(Async::Ready(()))
479+
});
480+
runtime.handle().spawn(connection).unwrap();
481+
482+
let listen = client.prepare("LISTEN test_notifications");
483+
let listen = runtime.block_on(listen).unwrap();
484+
runtime.block_on(client.execute(&listen, &[])).unwrap();
485+
drop(listen); // FIXME
486+
487+
let notify = client.prepare("NOTIFY test_notifications, 'hello'");
488+
let notify = runtime.block_on(notify).unwrap();
489+
runtime.block_on(client.execute(&notify, &[])).unwrap();
490+
drop(notify); // FIXME
491+
492+
let notify = client.prepare("NOTIFY test_notifications, 'world'");
493+
let notify = runtime.block_on(notify).unwrap();
494+
runtime.block_on(client.execute(&notify, &[])).unwrap();
495+
drop(notify); // FIXME
496+
497+
drop(client);
498+
runtime.run().unwrap();
499+
500+
let notifications = rx.collect().wait().unwrap();
501+
assert_eq!(notifications.len(), 2);
502+
assert_eq!(notifications[0].channel, "test_notifications");
503+
assert_eq!(notifications[0].payload, "hello");
504+
assert_eq!(notifications[1].channel, "test_notifications");
505+
assert_eq!(notifications[1].payload, "world");
506+
}

0 commit comments

Comments
 (0)