diff --git a/postgres-protocol/src/message/backend.rs b/postgres-protocol/src/message/backend.rs index 45e5c4074..22afd5ac4 100644 --- a/postgres-protocol/src/message/backend.rs +++ b/postgres-protocol/src/message/backend.rs @@ -22,6 +22,7 @@ pub const DATA_ROW_TAG: u8 = b'D'; pub const ERROR_RESPONSE_TAG: u8 = b'E'; pub const COPY_IN_RESPONSE_TAG: u8 = b'G'; pub const COPY_OUT_RESPONSE_TAG: u8 = b'H'; +pub const COPY_BOTH_RESPONSE_TAG: u8 = b'W'; pub const EMPTY_QUERY_RESPONSE_TAG: u8 = b'I'; pub const BACKEND_KEY_DATA_TAG: u8 = b'K'; pub const NO_DATA_TAG: u8 = b'n'; @@ -93,6 +94,7 @@ pub enum Message { CopyDone, CopyInResponse(CopyInResponseBody), CopyOutResponse(CopyOutResponseBody), + CopyBothResponse(CopyBothResponseBody), DataRow(DataRowBody), EmptyQueryResponse, ErrorResponse(ErrorResponseBody), @@ -190,6 +192,16 @@ impl Message { storage, }) } + COPY_BOTH_RESPONSE_TAG => { + let format = buf.read_u8()?; + let len = buf.read_u16::()?; + let storage = buf.read_all(); + Message::CopyBothResponse(CopyBothResponseBody { + format, + len, + storage, + }) + } EMPTY_QUERY_RESPONSE_TAG => Message::EmptyQueryResponse, BACKEND_KEY_DATA_TAG => { let process_id = buf.read_i32::()?; @@ -524,6 +536,27 @@ impl CopyOutResponseBody { } } +pub struct CopyBothResponseBody { + format: u8, + len: u16, + storage: Bytes, +} + +impl CopyBothResponseBody { + #[inline] + pub fn format(&self) -> u8 { + self.format + } + + #[inline] + pub fn column_formats(&self) -> ColumnFormats<'_> { + ColumnFormats { + remaining: self.len, + buf: &self.storage, + } + } +} + pub struct DataRowBody { storage: Bytes, len: u16, diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index 4a099d941..f74406ad1 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -1,6 +1,7 @@ use crate::codec::BackendMessages; use crate::config::{Host, SslMode}; use crate::connection::{Request, RequestMessages}; +use crate::copy_both::CopyBothDuplex; use crate::copy_out::CopyOutStream; use crate::query::RowStream; use crate::simple_query::SimpleQueryStream; @@ -11,8 +12,9 @@ use crate::types::{Oid, ToSql, Type}; #[cfg(feature = "runtime")] use crate::Socket; use crate::{ - copy_in, copy_out, prepare, query, simple_query, slice_iter, CancelToken, CopyInSink, Error, - Row, SimpleQueryMessage, Statement, ToStatement, Transaction, TransactionBuilder, + copy_both, copy_in, copy_out, prepare, query, simple_query, slice_iter, CancelToken, + CopyInSink, Error, Row, SimpleQueryMessage, Statement, ToStatement, Transaction, + TransactionBuilder, }; use bytes::{Buf, BytesMut}; use fallible_iterator::FallibleIterator; @@ -449,6 +451,15 @@ impl Client { copy_out::copy_out(self.inner(), statement).await } + /// Executes a CopyBoth query, returning a combined Stream+Sink type to read and write copy + /// data. + pub async fn copy_both_simple(&self, query: &str) -> Result, Error> + where + T: Buf + 'static + Send, + { + copy_both::copy_both_simple(self.inner(), query).await + } + /// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows. /// /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that diff --git a/tokio-postgres/src/connection.rs b/tokio-postgres/src/connection.rs index b6805f76c..e98056dcc 100644 --- a/tokio-postgres/src/connection.rs +++ b/tokio-postgres/src/connection.rs @@ -1,4 +1,5 @@ use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec}; +use crate::copy_both::CopyBothReceiver; use crate::copy_in::CopyInReceiver; use crate::error::DbError; use crate::maybe_tls_stream::MaybeTlsStream; @@ -21,6 +22,7 @@ use tokio_util::codec::Framed; pub enum RequestMessages { Single(FrontendMessage), CopyIn(CopyInReceiver), + CopyBoth(CopyBothReceiver), } pub struct Request { @@ -259,6 +261,24 @@ where .map_err(Error::io)?; self.pending_request = Some(RequestMessages::CopyIn(receiver)); } + RequestMessages::CopyBoth(mut receiver) => { + let message = match receiver.poll_next_unpin(cx) { + Poll::Ready(Some(message)) => message, + Poll::Ready(None) => { + trace!("poll_write: finished copy_both request"); + continue; + } + Poll::Pending => { + trace!("poll_write: waiting on copy_both stream"); + self.pending_request = Some(RequestMessages::CopyBoth(receiver)); + return Ok(true); + } + }; + Pin::new(&mut self.stream) + .start_send(message) + .map_err(Error::io)?; + self.pending_request = Some(RequestMessages::CopyBoth(receiver)); + } } } } diff --git a/tokio-postgres/src/copy_both.rs b/tokio-postgres/src/copy_both.rs new file mode 100644 index 000000000..3eaaae1e1 --- /dev/null +++ b/tokio-postgres/src/copy_both.rs @@ -0,0 +1,280 @@ +use crate::client::{InnerClient, Responses}; +use crate::codec::FrontendMessage; +use crate::connection::RequestMessages; +use crate::{simple_query, Error}; +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use futures::channel::{mpsc, oneshot}; +use futures::{ready, Sink, SinkExt, Stream, StreamExt}; +use log::debug; +use pin_project_lite::pin_project; +use postgres_protocol::message::backend::Message; +use postgres_protocol::message::frontend; +use postgres_protocol::message::frontend::CopyData; +use std::marker::{PhantomData, PhantomPinned}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// This stream is consumed by the Connection once we enter the CopyBoth sub protocol. When the +/// user drops their DuplexStream, CopyBothReceiver will automatically send a CopyDone or Sync +/// message to the backend +pub struct CopyBothReceiver { + message_receiver: mpsc::Receiver, + error_receiver: oneshot::Receiver<()>, + done: bool, +} + +impl CopyBothReceiver { + pub(crate) fn new( + message_receiver: mpsc::Receiver, + error_receiver: oneshot::Receiver<()>, + ) -> CopyBothReceiver { + CopyBothReceiver { + message_receiver, + error_receiver, + done: false, + } + } +} + +impl Stream for CopyBothReceiver { + type Item = FrontendMessage; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.done { + return Poll::Ready(None); + } + + // If an error was received from the backend we have to send a Sync message and not send + // any other messages + if let Ok(Some(_)) = self.error_receiver.try_recv() { + self.message_receiver.close(); + self.done = true; + let mut buf = BytesMut::new(); + frontend::sync(&mut buf); + return Poll::Ready(Some(FrontendMessage::Raw(buf.freeze()))); + } + + match ready!(self.message_receiver.poll_next_unpin(cx)) { + Some(message) => Poll::Ready(Some(message)), + None => { + self.done = true; + let mut buf = BytesMut::new(); + frontend::copy_done(&mut buf); + Poll::Ready(Some(FrontendMessage::Raw(buf.freeze()))) + } + } + } +} + +/// The state machine of a CopyBothDuplex +/// +/// ```ignore +/// CopyBoth +/// / \ +/// v v +/// CopyOut CopyIn +/// \ / +/// v v +/// ReadingCopy +/// | +/// v +/// Reading +/// ``` +#[derive(Clone, Copy)] +enum DuplexState { + /// Initial state where CopyData messages can go in both directions + CopyBoth, + /// The server->client stream is closed and we're in CopyIn mode + CopyIn, + /// The client->server stream is closed and we're in CopyOut mode + CopyOut, + /// Both directions are closed and we're waiting for the first CommandComplete message + ReadingCopy, + /// Both directions are closed and we're waiting for the final CommandComplete message + Reading, +} + +pin_project! { + /// A duplex stream for consuming streaming replication data. + /// + /// The stream side *must* be consumed even if not required in order to process the messages + /// coming from the server. If it is not, `Sink::close` may hang forever waiting for the stream + /// messages to be consumed. + /// + /// The copy should be explicitly completed via the `Sink::close` method to ensure all data is + /// flushed and errors are received. + pub struct CopyBothDuplex { + #[pin] + message_sender: mpsc::Sender, + #[pin] + error_sender: Option>, + responses: Responses, + buf: BytesMut, + state: DuplexState, + #[pin] + _p: PhantomPinned, + _p2: PhantomData, + } +} + +impl Stream for CopyBothDuplex { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + use DuplexState::*; + + let mut this = self.project(); + let unexpected = Poll::Ready(Some(Err(Error::unexpected_message()))); + + loop { + match ready!(this.responses.poll_next(cx)?) { + Message::CopyData(body) => { + return match this.state { + CopyBoth | CopyOut => Poll::Ready(Some(Ok(body.into_bytes()))), + _ => unexpected, + }; + } + // The server->client stream is done + Message::CopyDone => { + *this.state = match this.state { + CopyBoth => CopyIn, + CopyOut => ReadingCopy, + _ => return unexpected, + }; + } + // The server indicated an error, terminate both sides + Message::ErrorResponse(error) => { + *this.state = match this.state { + ReadingCopy | Reading => return unexpected, + _ => ReadingCopy, + }; + // Indicate to CopyBothReceiver to produce a Sync message instead of CopyDone + let _ = this.error_sender.take().unwrap().send(()); + return Poll::Ready(Some(Err(Error::db(error)))); + } + Message::CommandComplete(_) => { + match this.state { + ReadingCopy => *this.state = Reading, + Reading => return Poll::Ready(None), + _ => return unexpected, + }; + } + _ => return unexpected, + } + } + } +} + +impl Sink for CopyBothDuplex +where + T: Buf + 'static + Send, +{ + type Error = Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project() + .message_sender + .poll_ready(cx) + .map_err(|_| Error::closed()) + } + + fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Error> { + let this = self.project(); + + let data: Box = if item.remaining() > 4096 { + if this.buf.is_empty() { + Box::new(item) + } else { + Box::new(this.buf.split().freeze().chain(item)) + } + } else { + this.buf.put(item); + if this.buf.len() > 4096 { + Box::new(this.buf.split().freeze()) + } else { + return Ok(()); + } + }; + + let data = CopyData::new(data).map_err(Error::encode)?; + this.message_sender + .start_send(FrontendMessage::CopyData(data)) + .map_err(|_| Error::closed()) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + if !this.buf.is_empty() { + ready!(this.message_sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?; + let data: Box = Box::new(this.buf.split().freeze()); + let data = CopyData::new(data).map_err(Error::encode)?; + this.message_sender + .as_mut() + .start_send(FrontendMessage::CopyData(data)) + .map_err(|_| Error::closed())?; + } + + this.message_sender + .poll_flush(cx) + .map_err(|_| Error::closed()) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + use DuplexState::*; + + loop { + match self.state { + CopyBoth | CopyIn => { + ready!(self.as_mut().poll_flush(cx))?; + let mut this = self.as_mut().project(); + this.message_sender.disconnect(); + *this.state = match this.state { + CopyBoth => CopyOut, + CopyIn => ReadingCopy, + _ => unreachable!(), + }; + } + _ => return Poll::Ready(Ok(())), + } + } + } +} + +pub async fn copy_both_simple( + client: &InnerClient, + query: &str, +) -> Result, Error> +where + T: Buf + 'static + Send, +{ + debug!("executing copy both query {}", query); + + let buf = simple_query::encode(client, query)?; + + let (error_sender, error_receiver) = oneshot::channel(); + let (mut message_sender, message_receiver) = mpsc::channel(1); + + let receiver = CopyBothReceiver::new(message_receiver, error_receiver); + let mut responses = client.send(RequestMessages::CopyBoth(receiver))?; + + message_sender + .send(FrontendMessage::Raw(buf)) + .await + .map_err(|_| Error::closed())?; + + match responses.next().await? { + Message::CopyBothResponse(_) => {} + _ => return Err(Error::unexpected_message()), + } + + Ok(CopyBothDuplex { + error_sender: Some(error_sender), + message_sender, + responses, + buf: BytesMut::new(), + state: DuplexState::CopyBoth, + _p: PhantomPinned, + _p2: PhantomData, + }) +} diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index 77713bb11..0cb4a348b 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -155,6 +155,7 @@ mod connect_raw; mod connect_socket; mod connect_tls; mod connection; +mod copy_both; mod copy_in; mod copy_out; pub mod error; diff --git a/tokio-postgres/tests/test/copy_both.rs b/tokio-postgres/tests/test/copy_both.rs new file mode 100644 index 000000000..60d72c650 --- /dev/null +++ b/tokio-postgres/tests/test/copy_both.rs @@ -0,0 +1,91 @@ +use futures::{SinkExt, StreamExt, TryStreamExt}; +use std::future; +use tokio_postgres::NoTls; +use tokio_postgres::SimpleQueryMessage::Row; + +// https://stackoverflow.com/a/35907071 +fn contains(haystack: &[u8], needle: &[u8]) -> bool { + haystack + .windows(needle.len()) + .any(|window| window == needle) +} + +#[tokio::test] +async fn copy_both() { + copy_both_impl(true).await; + copy_both_impl(false).await; +} + +async fn copy_both_impl(graceful: bool) { + // form SQL connection + let conninfo = "host=127.0.0.1 port=5433 user=postgres replication=database"; + let (client, connection) = tokio_postgres::connect(conninfo, NoTls).await.unwrap(); + tokio::spawn(async move { + if let Err(e) = connection.await { + eprintln!("connection error: {}", e); + } + }); + + client + .simple_query("DROP TABLE IF EXISTS foo") + .await + .unwrap(); + client + .simple_query("CREATE TABLE foo (i text)") + .await + .unwrap(); + + let slot = client + .simple_query("CREATE_REPLICATION_SLOT slot TEMPORARY LOGICAL \"test_decoding\"") + .await + .unwrap(); + + let lsn = match &slot[0] { + Row(row) => row.get("consistent_point").unwrap(), + _ => panic!("unexpeced query message"), + }; + + client + .simple_query("INSERT INTO foo VALUES ('fde61ce315faac78b17c')") + .await + .unwrap(); + + let query = format!("START_REPLICATION SLOT slot LOGICAL {}", lsn); + let duplex_stream = client + .copy_both_simple::(&query) + .await + .unwrap(); + + let (mut sink, stream) = duplex_stream.split(); + + // We can only do some basic sanity checking of the raw stream. We filter for messages starting + // with 'w', which are XLogData messages, and check that they contain the transaction we expect + let actual: Vec<_> = stream + .try_filter(|buf| future::ready(buf[0] == b'w')) + .take(3) + .try_collect() + .await + .unwrap(); + let expected = &[ + "BEGIN", + "table public.foo: INSERT: i[text]:'fde61ce315faac78b17c'", + "COMMIT", + ]; + + for (msg, fragment) in actual.iter().zip(expected) { + assert!(contains(msg, fragment.as_bytes())); + } + + // Test that we can return into query processing mode + if graceful { + sink.close().await.unwrap(); + } else { + drop(sink); + } + + let q = client.simple_query("SELECT 'works' AS t").await.unwrap(); + match &q[0] { + Row(row) => assert_eq!(Some("works"), row.get("t")), + _ => panic!("unexpeced query message"), + }; +} diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index c367dbea3..e5ac8bcbd 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -17,6 +17,7 @@ use tokio_postgres::{ }; mod binary_copy; +mod copy_both; mod parse; #[cfg(feature = "runtime")] mod runtime;