1+ //! Utilities for working with the PostgreSQL binary copy format.
2+
3+ use crate :: types:: { FromSql , IsNull , ToSql , Type , WrongType } ;
4+ use crate :: { slice_iter, CopyInSink , CopyOutStream , Error } ;
15use byteorder:: { BigEndian , ByteOrder } ;
26use bytes:: { Buf , BufMut , Bytes , BytesMut } ;
37use futures:: { ready, SinkExt , Stream } ;
48use pin_project_lite:: pin_project;
59use std:: convert:: TryFrom ;
6- use std:: error :: Error ;
10+ use std:: io ;
711use std:: io:: Cursor ;
812use std:: ops:: Range ;
913use std:: pin:: Pin ;
1014use std:: sync:: Arc ;
1115use std:: task:: { Context , Poll } ;
12- use tokio_postgres:: types:: { FromSql , IsNull , ToSql , Type , WrongType } ;
13- use tokio_postgres:: { CopyInSink , CopyOutStream } ;
14-
15- #[ cfg( test) ]
16- mod test;
1716
1817const MAGIC : & [ u8 ] = b"PGCOPY\n \xff \r \n \0 " ;
1918const HEADER_LEN : usize = MAGIC . len ( ) + 4 + 4 ;
2019
2120pin_project ! {
21+ /// A type which serializes rows into the PostgreSQL binary copy format.
22+ ///
23+ /// The copy *must* be explicitly completed via the `finish` method. If it is not, the copy will be aborted.
2224 pub struct BinaryCopyInWriter {
2325 #[ pin]
2426 sink: CopyInSink <Bytes >,
@@ -28,10 +30,10 @@ pin_project! {
2830}
2931
3032impl BinaryCopyInWriter {
33+ /// Creates a new writer which will write rows of the provided types to the provided sink.
3134 pub fn new ( sink : CopyInSink < Bytes > , types : & [ Type ] ) -> BinaryCopyInWriter {
3235 let mut buf = BytesMut :: new ( ) ;
33- buf. reserve ( HEADER_LEN ) ;
34- buf. put_slice ( MAGIC ) ; // magic
36+ buf. put_slice ( MAGIC ) ;
3537 buf. put_i32 ( 0 ) ; // flags
3638 buf. put_i32 ( 0 ) ; // header extension
3739
@@ -42,19 +44,23 @@ impl BinaryCopyInWriter {
4244 }
4345 }
4446
45- pub async fn write (
46- self : Pin < & mut Self > ,
47- values : & [ & ( dyn ToSql + Send ) ] ,
48- ) -> Result < ( ) , Box < dyn Error + Sync + Send > > {
49- self . write_raw ( values. iter ( ) . cloned ( ) ) . await
47+ /// Writes a single row.
48+ ///
49+ /// # Panics
50+ ///
51+ /// Panics if the number of values provided does not match the number expected.
52+ pub async fn write ( self : Pin < & mut Self > , values : & [ & ( dyn ToSql + Sync ) ] ) -> Result < ( ) , Error > {
53+ self . write_raw ( slice_iter ( values) ) . await
5054 }
5155
52- pub async fn write_raw < ' a , I > (
53- self : Pin < & mut Self > ,
54- values : I ,
55- ) -> Result < ( ) , Box < dyn Error + Sync + Send > >
56+ /// A maximally-flexible version of `write`.
57+ ///
58+ /// # Panics
59+ ///
60+ /// Panics if the number of values provided does not match the number expected.
61+ pub async fn write_raw < ' a , I > ( self : Pin < & mut Self > , values : I ) -> Result < ( ) , Error >
5662 where
57- I : IntoIterator < Item = & ' a ( dyn ToSql + Send ) > ,
63+ I : IntoIterator < Item = & ' a dyn ToSql > ,
5864 I :: IntoIter : ExactSizeIterator ,
5965 {
6066 let mut this = self . project ( ) ;
@@ -69,12 +75,16 @@ impl BinaryCopyInWriter {
6975
7076 this. buf . put_i16 ( this. types . len ( ) as i16 ) ;
7177
72- for ( value, type_) in values. zip ( this. types ) {
78+ for ( i , ( value, type_) ) in values. zip ( this. types ) . enumerate ( ) {
7379 let idx = this. buf . len ( ) ;
7480 this. buf . put_i32 ( 0 ) ;
75- let len = match value. to_sql_checked ( type_, this. buf ) ? {
81+ let len = match value
82+ . to_sql_checked ( type_, this. buf )
83+ . map_err ( |e| Error :: to_sql ( e, i) ) ?
84+ {
7685 IsNull :: Yes => -1 ,
77- IsNull :: No => i32:: try_from ( this. buf . len ( ) - idx - 4 ) ?,
86+ IsNull :: No => i32:: try_from ( this. buf . len ( ) - idx - 4 )
87+ . map_err ( |e| Error :: encode ( io:: Error :: new ( io:: ErrorKind :: InvalidInput , e) ) ) ?,
7888 } ;
7989 BigEndian :: write_i32 ( & mut this. buf [ idx..] , len) ;
8090 }
@@ -86,7 +96,10 @@ impl BinaryCopyInWriter {
8696 Ok ( ( ) )
8797 }
8898
89- pub async fn finish ( self : Pin < & mut Self > ) -> Result < u64 , tokio_postgres:: Error > {
99+ /// Completes the copy, returning the number of rows added.
100+ ///
101+ /// This method *must* be used to complete the copy process. If it is not, the copy will be aborted.
102+ pub async fn finish ( self : Pin < & mut Self > ) -> Result < u64 , Error > {
90103 let mut this = self . project ( ) ;
91104
92105 this. buf . put_i16 ( -1 ) ;
@@ -100,6 +113,7 @@ struct Header {
100113}
101114
102115pin_project ! {
116+ /// A stream of rows deserialized from the PostgreSQL binary copy format.
103117 pub struct BinaryCopyOutStream {
104118 #[ pin]
105119 stream: CopyOutStream ,
@@ -109,7 +123,8 @@ pin_project! {
109123}
110124
111125impl BinaryCopyOutStream {
112- pub fn new ( types : & [ Type ] , stream : CopyOutStream ) -> BinaryCopyOutStream {
126+ /// Creates a stream from a raw copy out stream and the types of the columns being returned.
127+ pub fn new ( stream : CopyOutStream , types : & [ Type ] ) -> BinaryCopyOutStream {
113128 BinaryCopyOutStream {
114129 stream,
115130 types : Arc :: new ( types. to_vec ( ) ) ,
@@ -119,15 +134,15 @@ impl BinaryCopyOutStream {
119134}
120135
121136impl Stream for BinaryCopyOutStream {
122- type Item = Result < BinaryCopyOutRow , Box < dyn Error + Sync + Send > > ;
137+ type Item = Result < BinaryCopyOutRow , Error > ;
123138
124139 fn poll_next ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Option < Self :: Item > > {
125140 let this = self . project ( ) ;
126141
127142 let chunk = match ready ! ( this. stream. poll_next( cx) ) {
128143 Some ( Ok ( chunk) ) => chunk,
129- Some ( Err ( e) ) => return Poll :: Ready ( Some ( Err ( e. into ( ) ) ) ) ,
130- None => return Poll :: Ready ( Some ( Err ( "unexpected EOF" . into ( ) ) ) ) ,
144+ Some ( Err ( e) ) => return Poll :: Ready ( Some ( Err ( e) ) ) ,
145+ None => return Poll :: Ready ( Some ( Err ( Error :: closed ( ) ) ) ) ,
131146 } ;
132147 let mut chunk = Cursor :: new ( chunk) ;
133148
@@ -136,7 +151,10 @@ impl Stream for BinaryCopyOutStream {
136151 None => {
137152 check_remaining ( & chunk, HEADER_LEN ) ?;
138153 if & chunk. bytes ( ) [ ..MAGIC . len ( ) ] != MAGIC {
139- return Poll :: Ready ( Some ( Err ( "invalid magic value" . into ( ) ) ) ) ;
154+ return Poll :: Ready ( Some ( Err ( Error :: parse ( io:: Error :: new (
155+ io:: ErrorKind :: InvalidData ,
156+ "invalid magic value" ,
157+ ) ) ) ) ) ;
140158 }
141159 chunk. advance ( MAGIC . len ( ) ) ;
142160
@@ -162,7 +180,10 @@ impl Stream for BinaryCopyOutStream {
162180 len += 1 ;
163181 }
164182 if len as usize != this. types . len ( ) {
165- return Poll :: Ready ( Some ( Err ( "unexpected tuple size" . into ( ) ) ) ) ;
183+ return Poll :: Ready ( Some ( Err ( Error :: parse ( io:: Error :: new (
184+ io:: ErrorKind :: InvalidInput ,
185+ format ! ( "expected {} values but got {}" , this. types. len( ) , len) ,
186+ ) ) ) ) ) ;
166187 }
167188
168189 let mut ranges = vec ! [ ] ;
@@ -188,36 +209,55 @@ impl Stream for BinaryCopyOutStream {
188209 }
189210}
190211
191- fn check_remaining ( buf : & impl Buf , len : usize ) -> Result < ( ) , Box < dyn Error + Sync + Send > > {
212+ fn check_remaining ( buf : & Cursor < Bytes > , len : usize ) -> Result < ( ) , Error > {
192213 if buf. remaining ( ) < len {
193- Err ( "unexpected EOF" . into ( ) )
214+ Err ( Error :: parse ( io:: Error :: new (
215+ io:: ErrorKind :: UnexpectedEof ,
216+ "unexpected EOF" ,
217+ ) ) )
194218 } else {
195219 Ok ( ( ) )
196220 }
197221}
198222
223+ /// A row of data parsed from a binary copy out stream.
199224pub struct BinaryCopyOutRow {
200225 buf : Bytes ,
201226 ranges : Vec < Option < Range < usize > > > ,
202227 types : Arc < Vec < Type > > ,
203228}
204229
205230impl BinaryCopyOutRow {
206- pub fn try_get < ' a , T > ( & ' a self , idx : usize ) -> Result < T , Box < dyn Error + Sync + Send > >
231+ /// Like `get`, but returns a `Result` rather than panicking.
232+ pub fn try_get < ' a , T > ( & ' a self , idx : usize ) -> Result < T , Error >
207233 where
208234 T : FromSql < ' a > ,
209235 {
210- let type_ = & self . types [ idx] ;
236+ let type_ = match self . types . get ( idx) {
237+ Some ( type_) => type_,
238+ None => return Err ( Error :: column ( idx. to_string ( ) ) ) ,
239+ } ;
240+
211241 if !T :: accepts ( type_) {
212- return Err ( WrongType :: new :: < T > ( type_. clone ( ) ) . into ( ) ) ;
242+ return Err ( Error :: from_sql (
243+ Box :: new ( WrongType :: new :: < T > ( type_. clone ( ) ) ) ,
244+ idx,
245+ ) ) ;
213246 }
214247
215- match & self . ranges [ idx] {
216- Some ( range) => T :: from_sql ( type_, & self . buf [ range. clone ( ) ] ) . map_err ( Into :: into) ,
217- None => T :: from_sql_null ( type_) . map_err ( Into :: into) ,
218- }
248+ let r = match & self . ranges [ idx] {
249+ Some ( range) => T :: from_sql ( type_, & self . buf [ range. clone ( ) ] ) ,
250+ None => T :: from_sql_null ( type_) ,
251+ } ;
252+
253+ r. map_err ( |e| Error :: from_sql ( e, idx) )
219254 }
220255
256+ /// Deserializes a value from the row.
257+ ///
258+ /// # Panics
259+ ///
260+ /// Panics if the index is out of bounds or if the value cannot be converted to the specified type.
221261 pub fn get < ' a , T > ( & ' a self , idx : usize ) -> T
222262 where
223263 T : FromSql < ' a > ,
0 commit comments