Skip to content

Commit f744467

Browse files
committed
Add replication stream/sink splitting
There's some care to be taken when using this, more than the sort of splitting that we have in something like rustls-split. That's been described in more detail in the comment above `CopyBothDuplex::split`
1 parent 9eb0dbf commit f744467

File tree

2 files changed

+159
-2
lines changed

2 files changed

+159
-2
lines changed

Diff for: tokio-postgres/src/copy_both.rs

+158-1
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,18 @@ enum SinkState {
6969
}
7070

7171
pin_project! {
72-
/// A sink for `COPY ... FROM STDIN` query data.
72+
/// A sink & stream for `CopyBoth` replication messages
7373
///
7474
/// The copy *must* be explicitly completed via the `Sink::close` or `finish` methods. If it is
7575
/// not, the copy will be aborted.
76+
///
77+
/// The duplex can be split into the separate sink and stream with the [`split`] method. When
78+
/// using this, they must be re-joined before finishing in order to properly complete the copy.
79+
///
80+
/// Both the implementation of [`Stream`] and [`Sink`] provide access to the bytes wrapped
81+
/// inside of the `CopyData` wrapper.
82+
///
83+
/// [`split`]: Self::split
7684
pub struct CopyBothDuplex<T> {
7785
#[pin]
7886
sender: mpsc::Sender<CopyBothMessage>,
@@ -146,6 +154,53 @@ where
146154
pub async fn finish(mut self: Pin<&mut Self>) -> Result<u64, Error> {
147155
future::poll_fn(|cx| self.as_mut().poll_finish(cx)).await
148156
}
157+
158+
/// Splits the streams into distinct [`Sink`] and [`Stream`] components
159+
///
160+
/// Please note that there must be an eventual call to [`join`] the two components in order to
161+
/// properly close the connection with [`finish`]; no corresponding method exists for the two
162+
/// halves alone.
163+
///
164+
/// [`join`]: Self::join
165+
/// [`finish`]: Self::finish
166+
pub fn split(self) -> (Sender<T>, Receiver) {
167+
let send = Sender {
168+
sender: self.sender,
169+
buf: self.buf,
170+
state: self.state,
171+
marker: PhantomData,
172+
closed: false,
173+
};
174+
175+
let recv = Receiver {
176+
responses: self.responses,
177+
};
178+
179+
(send, recv)
180+
}
181+
182+
/// Joins the two halves of a `CopyBothDuplex` after a call to [`split`]
183+
///
184+
/// Note: We do not check that the sender and recevier originated from the same
185+
/// [`CopyBothDuplex`]. If they did not, unexpected behavior *will* occur.
186+
///
187+
/// ## Panics
188+
///
189+
/// If the sender has already been closed, this function will panic.
190+
///
191+
/// [`split`]: Self::split
192+
pub fn join(send: Sender<T>, recv: Receiver) -> Self {
193+
assert!(!send.closed);
194+
195+
CopyBothDuplex {
196+
sender: send.sender,
197+
responses: recv.responses,
198+
buf: send.buf,
199+
state: send.state,
200+
_p: PhantomPinned,
201+
_p2: PhantomData,
202+
}
203+
}
149204
}
150205

151206
impl<T> Stream for CopyBothDuplex<T> {
@@ -157,6 +212,7 @@ impl<T> Stream for CopyBothDuplex<T> {
157212
match ready!(this.responses.poll_next(cx)?) {
158213
Message::CopyData(body) => Poll::Ready(Some(Ok(body.into_bytes()))),
159214
Message::CopyDone => Poll::Ready(None),
215+
Message::ErrorResponse(body) => Poll::Ready(Some(Err(Error::db(body)))),
160216
_ => Poll::Ready(Some(Err(Error::unexpected_message()))),
161217
}
162218
}
@@ -220,6 +276,107 @@ where
220276
}
221277
}
222278

279+
pin_project! {
280+
/// The receiving half of a [`CopyBothDuplex`]
281+
///
282+
/// Receiving the next message is done through the [`Stream`] implementation.
283+
pub struct Receiver {
284+
responses: Responses,
285+
}
286+
}
287+
288+
pin_project! {
289+
/// The sending half of a [`CopyBothDuplex`]
290+
///
291+
/// Sending each message is done through the [`Sink`] implementation.
292+
pub struct Sender<T> {
293+
#[pin]
294+
sender: mpsc::Sender<CopyBothMessage>,
295+
buf: BytesMut,
296+
state: SinkState,
297+
marker: PhantomData<T>,
298+
// True iff the sink has been closed. Causes further operations to panic.
299+
closed: bool,
300+
}
301+
}
302+
303+
impl Stream for Receiver {
304+
type Item = Result<Bytes, Error>;
305+
306+
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
307+
let this = self.project();
308+
309+
match ready!(this.responses.poll_next(cx)?) {
310+
Message::CopyData(body) => Poll::Ready(Some(Ok(body.into_bytes()))),
311+
Message::CopyDone => Poll::Ready(None),
312+
Message::ErrorResponse(body) => Poll::Ready(Some(Err(Error::db(body)))),
313+
_ => Poll::Ready(Some(Err(Error::unexpected_message()))),
314+
}
315+
}
316+
}
317+
318+
impl<T> Sink<T> for Sender<T>
319+
where
320+
T: Buf + 'static + Send,
321+
{
322+
type Error = Error;
323+
324+
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
325+
self.project()
326+
.sender
327+
.poll_ready(cx)
328+
.map_err(|_| Error::closed())
329+
}
330+
331+
fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Error> {
332+
assert!(!self.closed);
333+
334+
let this = self.project();
335+
336+
let data: Box<dyn Buf + Send> = if item.remaining() > 4096 {
337+
if this.buf.is_empty() {
338+
Box::new(item)
339+
} else {
340+
Box::new(this.buf.split().freeze().chain(item))
341+
}
342+
} else {
343+
this.buf.put(item);
344+
if this.buf.len() > 4096 {
345+
Box::new(this.buf.split().freeze())
346+
} else {
347+
return Ok(());
348+
}
349+
};
350+
351+
let data = CopyData::new(data).map_err(Error::encode)?;
352+
this.sender
353+
.start_send(CopyBothMessage::Message(FrontendMessage::CopyData(data)))
354+
.map_err(|_| Error::closed())
355+
}
356+
357+
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
358+
let mut this = self.project();
359+
360+
if !this.buf.is_empty() {
361+
ready!(this.sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?;
362+
let data: Box<dyn Buf + Send> = Box::new(this.buf.split().freeze());
363+
let data = CopyData::new(data).map_err(Error::encode)?;
364+
this.sender
365+
.as_mut()
366+
.start_send(CopyBothMessage::Message(FrontendMessage::CopyData(data)))
367+
.map_err(|_| Error::closed())?;
368+
}
369+
370+
this.sender.poll_flush(cx).map_err(|_| Error::closed())
371+
}
372+
373+
// Closing the sink "normally" will just abort the copy.
374+
fn poll_close(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
375+
self.closed = true;
376+
Poll::Ready(Ok(()))
377+
}
378+
}
379+
223380
pub async fn copy_both_simple<T>(
224381
client: &InnerClient,
225382
query: &str,

Diff for: tokio-postgres/src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ pub use crate::cancel_token::CancelToken;
119119
pub use crate::client::Client;
120120
pub use crate::config::Config;
121121
pub use crate::connection::Connection;
122-
pub use crate::copy_both::CopyBothDuplex;
122+
pub use crate::copy_both::{CopyBothDuplex, Receiver as CopyBothStream, Sender as CopyBothSink};
123123
pub use crate::copy_in::CopyInSink;
124124
pub use crate::copy_out::CopyOutStream;
125125
use crate::error::DbError;

0 commit comments

Comments
 (0)