Skip to content

Commit ed05675

Browse files
committed
derive generic FromSql/ToSql
1 parent 5433118 commit ed05675

File tree

4 files changed

+116
-8
lines changed

4 files changed

+116
-8
lines changed

postgres-derive-test/src/composites.rs

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
use crate::test_type;
1+
use crate::{test_type, test_type_asymmetric};
22
use postgres::{Client, NoTls};
3-
use postgres_types::{FromSql, ToSql, WrongType};
3+
use postgres_types::{FromSql, FromSqlOwned, ToSql, WrongType};
44
use std::error::Error;
55

66
#[test]
@@ -238,3 +238,68 @@ fn raw_ident_field() {
238238

239239
test_type(&mut conn, "inventory_item", &[(item, "ROW('foo')")]);
240240
}
241+
242+
#[test]
243+
fn generics() {
244+
#[derive(FromSql, Debug, PartialEq)]
245+
struct InventoryItem<T: FromSqlOwned, U>
246+
where
247+
U: FromSqlOwned,
248+
{
249+
name: String,
250+
supplier_id: T,
251+
price: Option<U>,
252+
}
253+
254+
// doesn't make sense to implement derived FromSql on a type with borrows
255+
#[derive(ToSql, Debug, PartialEq)]
256+
#[postgres(name = "InventoryItem")]
257+
struct InventoryItemRef<'a, T: 'a + ToSql, U>
258+
where
259+
U: 'a + ToSql,
260+
{
261+
name: &'a str,
262+
supplier_id: &'a T,
263+
price: Option<&'a U>,
264+
}
265+
266+
const NAME: &str = "foobar";
267+
const SUPPLIER_ID: i32 = 100;
268+
const PRICE: f64 = 15.50;
269+
270+
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
271+
conn.batch_execute(
272+
"CREATE TYPE pg_temp.\"InventoryItem\" AS (
273+
name TEXT,
274+
supplier_id INT,
275+
price DOUBLE PRECISION
276+
);",
277+
)
278+
.unwrap();
279+
280+
let item = InventoryItemRef {
281+
name: NAME,
282+
supplier_id: &SUPPLIER_ID,
283+
price: Some(&PRICE),
284+
};
285+
286+
let item_null = InventoryItemRef {
287+
name: NAME,
288+
supplier_id: &SUPPLIER_ID,
289+
price: None,
290+
};
291+
292+
test_type_asymmetric(
293+
&mut conn,
294+
"\"InventoryItem\"",
295+
&[
296+
(item, "ROW('foobar', 100, 15.50)"),
297+
(item_null, "ROW('foobar', 100, NULL)"),
298+
],
299+
|t: &InventoryItemRef<i32, f64>, f: &InventoryItem<i32, f64>| {
300+
t.name == f.name.as_str()
301+
&& t.supplier_id == &f.supplier_id
302+
&& t.price == f.price.as_ref()
303+
},
304+
);
305+
}

postgres-derive-test/src/lib.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,30 @@ where
2727
}
2828
}
2929

30+
pub fn test_type_asymmetric<T, F, S, C>(
31+
conn: &mut Client,
32+
sql_type: &str,
33+
checks: &[(T, S)],
34+
cmp: C,
35+
) where
36+
T: ToSql + Sync,
37+
F: FromSqlOwned,
38+
S: fmt::Display,
39+
C: Fn(&T, &F) -> bool,
40+
{
41+
for &(ref val, ref repr) in checks.iter() {
42+
let stmt = conn
43+
.prepare(&*format!("SELECT {}::{}", *repr, sql_type))
44+
.unwrap();
45+
let result: F = conn.query_one(&stmt, &[]).unwrap().get(0);
46+
assert!(cmp(val, &result));
47+
48+
let stmt = conn.prepare(&*format!("SELECT $1::{}", sql_type)).unwrap();
49+
let result: F = conn.query_one(&stmt, &[val]).unwrap().get(0);
50+
assert!(cmp(val, &result));
51+
}
52+
}
53+
3054
#[test]
3155
fn compile_fail() {
3256
trybuild::TestCases::new().compile_fail("src/compile-fail/*.rs");

postgres-derive/src/fromsql.rs

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1-
use proc_macro2::TokenStream;
1+
use proc_macro2::{Span, TokenStream};
22
use quote::{format_ident, quote};
33
use std::iter;
4-
use syn::{Data, DataStruct, DeriveInput, Error, Fields, Ident};
4+
use syn::{
5+
Data, DataStruct, DeriveInput, Error, Fields, GenericParam, Generics, Ident, Lifetime,
6+
LifetimeDef,
7+
};
58

69
use crate::accepts;
710
use crate::composites::Field;
@@ -86,10 +89,13 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
8689
};
8790

8891
let ident = &input.ident;
92+
let (generics, lifetime) = build_generics(&input.generics);
93+
let (impl_generics, _, _) = generics.split_for_impl();
94+
let (_, ty_generics, where_clause) = input.generics.split_for_impl();
8995
let out = quote! {
90-
impl<'a> postgres_types::FromSql<'a> for #ident {
91-
fn from_sql(_type: &postgres_types::Type, buf: &'a [u8])
92-
-> std::result::Result<#ident,
96+
impl#impl_generics postgres_types::FromSql<#lifetime> for #ident#ty_generics #where_clause {
97+
fn from_sql(_type: &postgres_types::Type, buf: &#lifetime [u8])
98+
-> std::result::Result<#ident#ty_generics,
9399
std::boxed::Box<dyn std::error::Error +
94100
std::marker::Sync +
95101
std::marker::Send>> {
@@ -200,3 +206,15 @@ fn composite_body(ident: &Ident, fields: &[Field]) -> TokenStream {
200206
})
201207
}
202208
}
209+
210+
fn build_generics(source: &Generics) -> (Generics, Lifetime) {
211+
let mut out = source.to_owned();
212+
// don't worry about lifetime name collisions, it doesn't make sense to derive FromSql on a struct with a lifetime
213+
let lifetime = Lifetime::new("'a", Span::call_site());
214+
out.params.insert(
215+
0,
216+
GenericParam::Lifetime(LifetimeDef::new(lifetime.to_owned())),
217+
);
218+
219+
(out, lifetime)
220+
}

postgres-derive/src/tosql.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,9 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result<TokenStream, Error> {
8282
};
8383

8484
let ident = &input.ident;
85+
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
8586
let out = quote! {
86-
impl postgres_types::ToSql for #ident {
87+
impl#impl_generics postgres_types::ToSql for #ident#ty_generics #where_clause {
8788
fn to_sql(&self,
8889
_type: &postgres_types::Type,
8990
buf: &mut postgres_types::private::BytesMut)

0 commit comments

Comments
 (0)