1
+ //! Utilities for working with the PostgreSQL binary copy format.
2
+
3
+ use crate :: types:: { FromSql , IsNull , ToSql , Type , WrongType } ;
4
+ use crate :: { slice_iter, CopyInSink , CopyOutStream , Error } ;
1
5
use byteorder:: { BigEndian , ByteOrder } ;
2
6
use bytes:: { Buf , BufMut , Bytes , BytesMut } ;
3
7
use futures:: { ready, SinkExt , Stream } ;
4
8
use pin_project_lite:: pin_project;
5
9
use std:: convert:: TryFrom ;
6
- use std:: error :: Error ;
10
+ use std:: io ;
7
11
use std:: io:: Cursor ;
8
12
use std:: ops:: Range ;
9
13
use std:: pin:: Pin ;
10
14
use std:: sync:: Arc ;
11
15
use std:: task:: { Context , Poll } ;
12
- use tokio_postgres:: types:: { FromSql , IsNull , ToSql , Type , WrongType } ;
13
- use tokio_postgres:: { CopyInSink , CopyOutStream } ;
14
-
15
- #[ cfg( test) ]
16
- mod test;
17
16
18
17
const MAGIC : & [ u8 ] = b"PGCOPY\n \xff \r \n \0 " ;
19
18
const HEADER_LEN : usize = MAGIC . len ( ) + 4 + 4 ;
20
19
21
20
pin_project ! {
21
+ /// A type which serializes rows into the PostgreSQL binary copy format.
22
+ ///
23
+ /// The copy *must* be explicitly completed via the `finish` method. If it is not, the copy will be aborted.
22
24
pub struct BinaryCopyInWriter {
23
25
#[ pin]
24
26
sink: CopyInSink <Bytes >,
@@ -28,10 +30,10 @@ pin_project! {
28
30
}
29
31
30
32
impl BinaryCopyInWriter {
33
+ /// Creates a new writer which will write rows of the provided types to the provided sink.
31
34
pub fn new ( sink : CopyInSink < Bytes > , types : & [ Type ] ) -> BinaryCopyInWriter {
32
35
let mut buf = BytesMut :: new ( ) ;
33
- buf. reserve ( HEADER_LEN ) ;
34
- buf. put_slice ( MAGIC ) ; // magic
36
+ buf. put_slice ( MAGIC ) ;
35
37
buf. put_i32 ( 0 ) ; // flags
36
38
buf. put_i32 ( 0 ) ; // header extension
37
39
@@ -42,19 +44,23 @@ impl BinaryCopyInWriter {
42
44
}
43
45
}
44
46
45
- pub async fn write (
46
- self : Pin < & mut Self > ,
47
- values : & [ & ( dyn ToSql + Send ) ] ,
48
- ) -> Result < ( ) , Box < dyn Error + Sync + Send > > {
49
- self . write_raw ( values. iter ( ) . cloned ( ) ) . await
47
+ /// Writes a single row.
48
+ ///
49
+ /// # Panics
50
+ ///
51
+ /// Panics if the number of values provided does not match the number expected.
52
+ pub async fn write ( self : Pin < & mut Self > , values : & [ & ( dyn ToSql + Sync ) ] ) -> Result < ( ) , Error > {
53
+ self . write_raw ( slice_iter ( values) ) . await
50
54
}
51
55
52
- pub async fn write_raw < ' a , I > (
53
- self : Pin < & mut Self > ,
54
- values : I ,
55
- ) -> Result < ( ) , Box < dyn Error + Sync + Send > >
56
+ /// A maximally-flexible version of `write`.
57
+ ///
58
+ /// # Panics
59
+ ///
60
+ /// Panics if the number of values provided does not match the number expected.
61
+ pub async fn write_raw < ' a , I > ( self : Pin < & mut Self > , values : I ) -> Result < ( ) , Error >
56
62
where
57
- I : IntoIterator < Item = & ' a ( dyn ToSql + Send ) > ,
63
+ I : IntoIterator < Item = & ' a dyn ToSql > ,
58
64
I :: IntoIter : ExactSizeIterator ,
59
65
{
60
66
let mut this = self . project ( ) ;
@@ -69,12 +75,16 @@ impl BinaryCopyInWriter {
69
75
70
76
this. buf . put_i16 ( this. types . len ( ) as i16 ) ;
71
77
72
- for ( value, type_) in values. zip ( this. types ) {
78
+ for ( i , ( value, type_) ) in values. zip ( this. types ) . enumerate ( ) {
73
79
let idx = this. buf . len ( ) ;
74
80
this. buf . put_i32 ( 0 ) ;
75
- let len = match value. to_sql_checked ( type_, this. buf ) ? {
81
+ let len = match value
82
+ . to_sql_checked ( type_, this. buf )
83
+ . map_err ( |e| Error :: to_sql ( e, i) ) ?
84
+ {
76
85
IsNull :: Yes => -1 ,
77
- IsNull :: No => i32:: try_from ( this. buf . len ( ) - idx - 4 ) ?,
86
+ IsNull :: No => i32:: try_from ( this. buf . len ( ) - idx - 4 )
87
+ . map_err ( |e| Error :: encode ( io:: Error :: new ( io:: ErrorKind :: InvalidInput , e) ) ) ?,
78
88
} ;
79
89
BigEndian :: write_i32 ( & mut this. buf [ idx..] , len) ;
80
90
}
@@ -86,7 +96,10 @@ impl BinaryCopyInWriter {
86
96
Ok ( ( ) )
87
97
}
88
98
89
- pub async fn finish ( self : Pin < & mut Self > ) -> Result < u64 , tokio_postgres:: Error > {
99
+ /// Completes the copy, returning the number of rows added.
100
+ ///
101
+ /// This method *must* be used to complete the copy process. If it is not, the copy will be aborted.
102
+ pub async fn finish ( self : Pin < & mut Self > ) -> Result < u64 , Error > {
90
103
let mut this = self . project ( ) ;
91
104
92
105
this. buf . put_i16 ( -1 ) ;
@@ -100,6 +113,7 @@ struct Header {
100
113
}
101
114
102
115
pin_project ! {
116
+ /// A stream of rows deserialized from the PostgreSQL binary copy format.
103
117
pub struct BinaryCopyOutStream {
104
118
#[ pin]
105
119
stream: CopyOutStream ,
@@ -109,7 +123,8 @@ pin_project! {
109
123
}
110
124
111
125
impl BinaryCopyOutStream {
112
- pub fn new ( types : & [ Type ] , stream : CopyOutStream ) -> BinaryCopyOutStream {
126
+ /// Creates a stream from a raw copy out stream and the types of the columns being returned.
127
+ pub fn new ( stream : CopyOutStream , types : & [ Type ] ) -> BinaryCopyOutStream {
113
128
BinaryCopyOutStream {
114
129
stream,
115
130
types : Arc :: new ( types. to_vec ( ) ) ,
@@ -119,15 +134,15 @@ impl BinaryCopyOutStream {
119
134
}
120
135
121
136
impl Stream for BinaryCopyOutStream {
122
- type Item = Result < BinaryCopyOutRow , Box < dyn Error + Sync + Send > > ;
137
+ type Item = Result < BinaryCopyOutRow , Error > ;
123
138
124
139
fn poll_next ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Option < Self :: Item > > {
125
140
let this = self . project ( ) ;
126
141
127
142
let chunk = match ready ! ( this. stream. poll_next( cx) ) {
128
143
Some ( Ok ( chunk) ) => chunk,
129
- Some ( Err ( e) ) => return Poll :: Ready ( Some ( Err ( e. into ( ) ) ) ) ,
130
- None => return Poll :: Ready ( Some ( Err ( "unexpected EOF" . into ( ) ) ) ) ,
144
+ Some ( Err ( e) ) => return Poll :: Ready ( Some ( Err ( e) ) ) ,
145
+ None => return Poll :: Ready ( Some ( Err ( Error :: closed ( ) ) ) ) ,
131
146
} ;
132
147
let mut chunk = Cursor :: new ( chunk) ;
133
148
@@ -136,7 +151,10 @@ impl Stream for BinaryCopyOutStream {
136
151
None => {
137
152
check_remaining ( & chunk, HEADER_LEN ) ?;
138
153
if & chunk. bytes ( ) [ ..MAGIC . len ( ) ] != MAGIC {
139
- return Poll :: Ready ( Some ( Err ( "invalid magic value" . into ( ) ) ) ) ;
154
+ return Poll :: Ready ( Some ( Err ( Error :: parse ( io:: Error :: new (
155
+ io:: ErrorKind :: InvalidData ,
156
+ "invalid magic value" ,
157
+ ) ) ) ) ) ;
140
158
}
141
159
chunk. advance ( MAGIC . len ( ) ) ;
142
160
@@ -162,7 +180,10 @@ impl Stream for BinaryCopyOutStream {
162
180
len += 1 ;
163
181
}
164
182
if len as usize != this. types . len ( ) {
165
- return Poll :: Ready ( Some ( Err ( "unexpected tuple size" . into ( ) ) ) ) ;
183
+ return Poll :: Ready ( Some ( Err ( Error :: parse ( io:: Error :: new (
184
+ io:: ErrorKind :: InvalidInput ,
185
+ format ! ( "expected {} values but got {}" , this. types. len( ) , len) ,
186
+ ) ) ) ) ) ;
166
187
}
167
188
168
189
let mut ranges = vec ! [ ] ;
@@ -188,36 +209,55 @@ impl Stream for BinaryCopyOutStream {
188
209
}
189
210
}
190
211
191
- fn check_remaining ( buf : & impl Buf , len : usize ) -> Result < ( ) , Box < dyn Error + Sync + Send > > {
212
+ fn check_remaining ( buf : & Cursor < Bytes > , len : usize ) -> Result < ( ) , Error > {
192
213
if buf. remaining ( ) < len {
193
- Err ( "unexpected EOF" . into ( ) )
214
+ Err ( Error :: parse ( io:: Error :: new (
215
+ io:: ErrorKind :: UnexpectedEof ,
216
+ "unexpected EOF" ,
217
+ ) ) )
194
218
} else {
195
219
Ok ( ( ) )
196
220
}
197
221
}
198
222
223
+ /// A row of data parsed from a binary copy out stream.
199
224
pub struct BinaryCopyOutRow {
200
225
buf : Bytes ,
201
226
ranges : Vec < Option < Range < usize > > > ,
202
227
types : Arc < Vec < Type > > ,
203
228
}
204
229
205
230
impl BinaryCopyOutRow {
206
- pub fn try_get < ' a , T > ( & ' a self , idx : usize ) -> Result < T , Box < dyn Error + Sync + Send > >
231
+ /// Like `get`, but returns a `Result` rather than panicking.
232
+ pub fn try_get < ' a , T > ( & ' a self , idx : usize ) -> Result < T , Error >
207
233
where
208
234
T : FromSql < ' a > ,
209
235
{
210
- let type_ = & self . types [ idx] ;
236
+ let type_ = match self . types . get ( idx) {
237
+ Some ( type_) => type_,
238
+ None => return Err ( Error :: column ( idx. to_string ( ) ) ) ,
239
+ } ;
240
+
211
241
if !T :: accepts ( type_) {
212
- return Err ( WrongType :: new :: < T > ( type_. clone ( ) ) . into ( ) ) ;
242
+ return Err ( Error :: from_sql (
243
+ Box :: new ( WrongType :: new :: < T > ( type_. clone ( ) ) ) ,
244
+ idx,
245
+ ) ) ;
213
246
}
214
247
215
- match & self . ranges [ idx] {
216
- Some ( range) => T :: from_sql ( type_, & self . buf [ range. clone ( ) ] ) . map_err ( Into :: into) ,
217
- None => T :: from_sql_null ( type_) . map_err ( Into :: into) ,
218
- }
248
+ let r = match & self . ranges [ idx] {
249
+ Some ( range) => T :: from_sql ( type_, & self . buf [ range. clone ( ) ] ) ,
250
+ None => T :: from_sql_null ( type_) ,
251
+ } ;
252
+
253
+ r. map_err ( |e| Error :: from_sql ( e, idx) )
219
254
}
220
255
256
+ /// Deserializes a value from the row.
257
+ ///
258
+ /// # Panics
259
+ ///
260
+ /// Panics if the index is out of bounds or if the value cannot be converted to the specified type.
221
261
pub fn get < ' a , T > ( & ' a self , idx : usize ) -> T
222
262
where
223
263
T : FromSql < ' a > ,
0 commit comments