Skip to content

Commit bf8b335

Browse files
committedDec 9, 2019
Move binary copy stuff directly into main crate
1 parent 0c84ed9 commit bf8b335

File tree

7 files changed

+91
-76
lines changed

7 files changed

+91
-76
lines changed
 

‎Cargo.toml

-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ members = [
99
"postgres-protocol",
1010
"postgres-types",
1111
"tokio-postgres",
12-
"tokio-postgres-binary-copy",
1312
]
1413

1514
[profile.release]

‎tokio-postgres-binary-copy/Cargo.toml

-16
This file was deleted.

‎tokio-postgres/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ with-uuid-0_8 = ["postgres-types/with-uuid-0_8"]
3737

3838
[dependencies]
3939
bytes = "0.5"
40+
byteorder = "1.0"
4041
fallible-iterator = "0.2"
4142
futures = "0.3"
4243
log = "0.4"

‎tokio-postgres-binary-copy/src/lib.rs renamed to ‎tokio-postgres/src/binary_copy.rs

+77-37
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,26 @@
1+
//! Utilities for working with the PostgreSQL binary copy format.
2+
3+
use crate::types::{FromSql, IsNull, ToSql, Type, WrongType};
4+
use crate::{slice_iter, CopyInSink, CopyOutStream, Error};
15
use byteorder::{BigEndian, ByteOrder};
26
use bytes::{Buf, BufMut, Bytes, BytesMut};
37
use futures::{ready, SinkExt, Stream};
48
use pin_project_lite::pin_project;
59
use std::convert::TryFrom;
6-
use std::error::Error;
10+
use std::io;
711
use std::io::Cursor;
812
use std::ops::Range;
913
use std::pin::Pin;
1014
use std::sync::Arc;
1115
use std::task::{Context, Poll};
12-
use tokio_postgres::types::{FromSql, IsNull, ToSql, Type, WrongType};
13-
use tokio_postgres::{CopyInSink, CopyOutStream};
14-
15-
#[cfg(test)]
16-
mod test;
1716

1817
const MAGIC: &[u8] = b"PGCOPY\n\xff\r\n\0";
1918
const HEADER_LEN: usize = MAGIC.len() + 4 + 4;
2019

2120
pin_project! {
21+
/// A type which serializes rows into the PostgreSQL binary copy format.
22+
///
23+
/// The copy *must* be explicitly completed via the `finish` method. If it is not, the copy will be aborted.
2224
pub struct BinaryCopyInWriter {
2325
#[pin]
2426
sink: CopyInSink<Bytes>,
@@ -28,10 +30,10 @@ pin_project! {
2830
}
2931

3032
impl BinaryCopyInWriter {
33+
/// Creates a new writer which will write rows of the provided types to the provided sink.
3134
pub fn new(sink: CopyInSink<Bytes>, types: &[Type]) -> BinaryCopyInWriter {
3235
let mut buf = BytesMut::new();
33-
buf.reserve(HEADER_LEN);
34-
buf.put_slice(MAGIC); // magic
36+
buf.put_slice(MAGIC);
3537
buf.put_i32(0); // flags
3638
buf.put_i32(0); // header extension
3739

@@ -42,19 +44,23 @@ impl BinaryCopyInWriter {
4244
}
4345
}
4446

45-
pub async fn write(
46-
self: Pin<&mut Self>,
47-
values: &[&(dyn ToSql + Send)],
48-
) -> Result<(), Box<dyn Error + Sync + Send>> {
49-
self.write_raw(values.iter().cloned()).await
47+
/// Writes a single row.
48+
///
49+
/// # Panics
50+
///
51+
/// Panics if the number of values provided does not match the number expected.
52+
pub async fn write(self: Pin<&mut Self>, values: &[&(dyn ToSql + Sync)]) -> Result<(), Error> {
53+
self.write_raw(slice_iter(values)).await
5054
}
5155

52-
pub async fn write_raw<'a, I>(
53-
self: Pin<&mut Self>,
54-
values: I,
55-
) -> Result<(), Box<dyn Error + Sync + Send>>
56+
/// A maximally-flexible version of `write`.
57+
///
58+
/// # Panics
59+
///
60+
/// Panics if the number of values provided does not match the number expected.
61+
pub async fn write_raw<'a, I>(self: Pin<&mut Self>, values: I) -> Result<(), Error>
5662
where
57-
I: IntoIterator<Item = &'a (dyn ToSql + Send)>,
63+
I: IntoIterator<Item = &'a dyn ToSql>,
5864
I::IntoIter: ExactSizeIterator,
5965
{
6066
let mut this = self.project();
@@ -69,12 +75,16 @@ impl BinaryCopyInWriter {
6975

7076
this.buf.put_i16(this.types.len() as i16);
7177

72-
for (value, type_) in values.zip(this.types) {
78+
for (i, (value, type_)) in values.zip(this.types).enumerate() {
7379
let idx = this.buf.len();
7480
this.buf.put_i32(0);
75-
let len = match value.to_sql_checked(type_, this.buf)? {
81+
let len = match value
82+
.to_sql_checked(type_, this.buf)
83+
.map_err(|e| Error::to_sql(e, i))?
84+
{
7685
IsNull::Yes => -1,
77-
IsNull::No => i32::try_from(this.buf.len() - idx - 4)?,
86+
IsNull::No => i32::try_from(this.buf.len() - idx - 4)
87+
.map_err(|e| Error::encode(io::Error::new(io::ErrorKind::InvalidInput, e)))?,
7888
};
7989
BigEndian::write_i32(&mut this.buf[idx..], len);
8090
}
@@ -86,7 +96,10 @@ impl BinaryCopyInWriter {
8696
Ok(())
8797
}
8898

89-
pub async fn finish(self: Pin<&mut Self>) -> Result<u64, tokio_postgres::Error> {
99+
/// Completes the copy, returning the number of rows added.
100+
///
101+
/// This method *must* be used to complete the copy process. If it is not, the copy will be aborted.
102+
pub async fn finish(self: Pin<&mut Self>) -> Result<u64, Error> {
90103
let mut this = self.project();
91104

92105
this.buf.put_i16(-1);
@@ -100,6 +113,7 @@ struct Header {
100113
}
101114

102115
pin_project! {
116+
/// A stream of rows deserialized from the PostgreSQL binary copy format.
103117
pub struct BinaryCopyOutStream {
104118
#[pin]
105119
stream: CopyOutStream,
@@ -109,7 +123,8 @@ pin_project! {
109123
}
110124

111125
impl BinaryCopyOutStream {
112-
pub fn new(types: &[Type], stream: CopyOutStream) -> BinaryCopyOutStream {
126+
/// Creates a stream from a raw copy out stream and the types of the columns being returned.
127+
pub fn new(stream: CopyOutStream, types: &[Type]) -> BinaryCopyOutStream {
113128
BinaryCopyOutStream {
114129
stream,
115130
types: Arc::new(types.to_vec()),
@@ -119,15 +134,15 @@ impl BinaryCopyOutStream {
119134
}
120135

121136
impl Stream for BinaryCopyOutStream {
122-
type Item = Result<BinaryCopyOutRow, Box<dyn Error + Sync + Send>>;
137+
type Item = Result<BinaryCopyOutRow, Error>;
123138

124139
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
125140
let this = self.project();
126141

127142
let chunk = match ready!(this.stream.poll_next(cx)) {
128143
Some(Ok(chunk)) => chunk,
129-
Some(Err(e)) => return Poll::Ready(Some(Err(e.into()))),
130-
None => return Poll::Ready(Some(Err("unexpected EOF".into()))),
144+
Some(Err(e)) => return Poll::Ready(Some(Err(e))),
145+
None => return Poll::Ready(Some(Err(Error::closed()))),
131146
};
132147
let mut chunk = Cursor::new(chunk);
133148

@@ -136,7 +151,10 @@ impl Stream for BinaryCopyOutStream {
136151
None => {
137152
check_remaining(&chunk, HEADER_LEN)?;
138153
if &chunk.bytes()[..MAGIC.len()] != MAGIC {
139-
return Poll::Ready(Some(Err("invalid magic value".into())));
154+
return Poll::Ready(Some(Err(Error::parse(io::Error::new(
155+
io::ErrorKind::InvalidData,
156+
"invalid magic value",
157+
)))));
140158
}
141159
chunk.advance(MAGIC.len());
142160

@@ -162,7 +180,10 @@ impl Stream for BinaryCopyOutStream {
162180
len += 1;
163181
}
164182
if len as usize != this.types.len() {
165-
return Poll::Ready(Some(Err("unexpected tuple size".into())));
183+
return Poll::Ready(Some(Err(Error::parse(io::Error::new(
184+
io::ErrorKind::InvalidInput,
185+
format!("expected {} values but got {}", this.types.len(), len),
186+
)))));
166187
}
167188

168189
let mut ranges = vec![];
@@ -188,36 +209,55 @@ impl Stream for BinaryCopyOutStream {
188209
}
189210
}
190211

191-
fn check_remaining(buf: &impl Buf, len: usize) -> Result<(), Box<dyn Error + Sync + Send>> {
212+
fn check_remaining(buf: &Cursor<Bytes>, len: usize) -> Result<(), Error> {
192213
if buf.remaining() < len {
193-
Err("unexpected EOF".into())
214+
Err(Error::parse(io::Error::new(
215+
io::ErrorKind::UnexpectedEof,
216+
"unexpected EOF",
217+
)))
194218
} else {
195219
Ok(())
196220
}
197221
}
198222

223+
/// A row of data parsed from a binary copy out stream.
199224
pub struct BinaryCopyOutRow {
200225
buf: Bytes,
201226
ranges: Vec<Option<Range<usize>>>,
202227
types: Arc<Vec<Type>>,
203228
}
204229

205230
impl BinaryCopyOutRow {
206-
pub fn try_get<'a, T>(&'a self, idx: usize) -> Result<T, Box<dyn Error + Sync + Send>>
231+
/// Like `get`, but returns a `Result` rather than panicking.
232+
pub fn try_get<'a, T>(&'a self, idx: usize) -> Result<T, Error>
207233
where
208234
T: FromSql<'a>,
209235
{
210-
let type_ = &self.types[idx];
236+
let type_ = match self.types.get(idx) {
237+
Some(type_) => type_,
238+
None => return Err(Error::column(idx.to_string())),
239+
};
240+
211241
if !T::accepts(type_) {
212-
return Err(WrongType::new::<T>(type_.clone()).into());
242+
return Err(Error::from_sql(
243+
Box::new(WrongType::new::<T>(type_.clone())),
244+
idx,
245+
));
213246
}
214247

215-
match &self.ranges[idx] {
216-
Some(range) => T::from_sql(type_, &self.buf[range.clone()]).map_err(Into::into),
217-
None => T::from_sql_null(type_).map_err(Into::into),
218-
}
248+
let r = match &self.ranges[idx] {
249+
Some(range) => T::from_sql(type_, &self.buf[range.clone()]),
250+
None => T::from_sql_null(type_),
251+
};
252+
253+
r.map_err(|e| Error::from_sql(e, idx))
219254
}
220255

256+
/// Deserializes a value from the row.
257+
///
258+
/// # Panics
259+
///
260+
/// Panics if the index is out of bounds or if the value cannot be converted to the specified type.
221261
pub fn get<'a, T>(&'a self, idx: usize) -> T
222262
where
223263
T: FromSql<'a>,

‎tokio-postgres/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ pub use crate::to_statement::ToStatement;
123123
pub use crate::transaction::Transaction;
124124
use crate::types::ToSql;
125125

126+
pub mod binary_copy;
126127
mod bind;
127128
#[cfg(feature = "runtime")]
128129
mod cancel_query;

‎tokio-postgres-binary-copy/src/test.rs renamed to ‎tokio-postgres/tests/test/binary_copy.rs

+11-22
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,11 @@
1-
use crate::{BinaryCopyInWriter, BinaryCopyOutStream};
1+
use crate::connect;
22
use futures::{pin_mut, TryStreamExt};
3+
use tokio_postgres::binary_copy::{BinaryCopyInWriter, BinaryCopyOutStream};
34
use tokio_postgres::types::Type;
4-
use tokio_postgres::{Client, NoTls};
5-
6-
async fn connect() -> Client {
7-
let (client, connection) =
8-
tokio_postgres::connect("host=localhost port=5433 user=postgres", NoTls)
9-
.await
10-
.unwrap();
11-
tokio::spawn(async {
12-
connection.await.unwrap();
13-
});
14-
client
15-
}
165

176
#[tokio::test]
187
async fn write_basic() {
19-
let client = connect().await;
8+
let client = connect("user=postgres").await;
209

2110
client
2211
.batch_execute("CREATE TEMPORARY TABLE foo (id INT, bar TEXT)")
@@ -50,7 +39,7 @@ async fn write_basic() {
5039

5140
#[tokio::test]
5241
async fn write_many_rows() {
53-
let client = connect().await;
42+
let client = connect("user=postgres").await;
5443

5544
client
5645
.batch_execute("CREATE TEMPORARY TABLE foo (id INT, bar TEXT)")
@@ -86,7 +75,7 @@ async fn write_many_rows() {
8675

8776
#[tokio::test]
8877
async fn write_big_rows() {
89-
let client = connect().await;
78+
let client = connect("user=postgres").await;
9079

9180
client
9281
.batch_execute("CREATE TEMPORARY TABLE foo (id INT, bar BYTEA)")
@@ -122,7 +111,7 @@ async fn write_big_rows() {
122111

123112
#[tokio::test]
124113
async fn read_basic() {
125-
let client = connect().await;
114+
let client = connect("user=postgres").await;
126115

127116
client
128117
.batch_execute(
@@ -138,7 +127,7 @@ async fn read_basic() {
138127
.copy_out("COPY foo (id, bar) TO STDIN BINARY")
139128
.await
140129
.unwrap();
141-
let rows = BinaryCopyOutStream::new(&[Type::INT4, Type::TEXT], stream)
130+
let rows = BinaryCopyOutStream::new(stream, &[Type::INT4, Type::TEXT])
142131
.try_collect::<Vec<_>>()
143132
.await
144133
.unwrap();
@@ -152,7 +141,7 @@ async fn read_basic() {
152141

153142
#[tokio::test]
154143
async fn read_many_rows() {
155-
let client = connect().await;
144+
let client = connect("user=postgres").await;
156145

157146
client
158147
.batch_execute(
@@ -167,7 +156,7 @@ async fn read_many_rows() {
167156
.copy_out("COPY foo (id, bar) TO STDIN BINARY")
168157
.await
169158
.unwrap();
170-
let rows = BinaryCopyOutStream::new(&[Type::INT4, Type::TEXT], stream)
159+
let rows = BinaryCopyOutStream::new(stream, &[Type::INT4, Type::TEXT])
171160
.try_collect::<Vec<_>>()
172161
.await
173162
.unwrap();
@@ -181,7 +170,7 @@ async fn read_many_rows() {
181170

182171
#[tokio::test]
183172
async fn read_big_rows() {
184-
let client = connect().await;
173+
let client = connect("user=postgres").await;
185174

186175
client
187176
.batch_execute("CREATE TEMPORARY TABLE foo (id INT, bar BYTEA)")
@@ -201,7 +190,7 @@ async fn read_big_rows() {
201190
.copy_out("COPY foo (id, bar) TO STDIN BINARY")
202191
.await
203192
.unwrap();
204-
let rows = BinaryCopyOutStream::new(&[Type::INT4, Type::BYTEA], stream)
193+
let rows = BinaryCopyOutStream::new(stream, &[Type::INT4, Type::BYTEA])
205194
.try_collect::<Vec<_>>()
206195
.await
207196
.unwrap();

‎tokio-postgres/tests/test/main.rs

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use tokio_postgres::tls::{NoTls, NoTlsStream};
1414
use tokio_postgres::types::{Kind, Type};
1515
use tokio_postgres::{AsyncMessage, Client, Config, Connection, Error, SimpleQueryMessage};
1616

17+
mod binary_copy;
1718
mod parse;
1819
#[cfg(feature = "runtime")]
1920
mod runtime;

0 commit comments

Comments
 (0)
Please sign in to comment.