Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds support for zero-copy FromSql derive #1070

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions postgres-derive-test/src/composites.rs
Original file line number Diff line number Diff line change
Expand Up @@ -346,3 +346,32 @@ fn generics() {
},
);
}

#[test]
fn struct_with_borrowed_fields() {
#[derive(FromSql, ToSql, Debug, PartialEq)]
#[postgres(name = "item")]
struct Item<'a, 'b: 'a> {
name: &'a str,
data: &'b [u8],
}

let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
conn.batch_execute(
"CREATE TYPE pg_temp.item AS (
name TEXT,
data BYTEA
);",
)
.unwrap();

let item = Item {
name: "foobar",
data: b"12345",
};

let row = conn.query_one("SELECT $1::item", &[&item]).unwrap();
let result: Item<'_, '_> = row.get(0);
assert_eq!(item.name, result.name);
assert_eq!(item.data, result.data);
}
35 changes: 35 additions & 0 deletions postgres-derive-test/src/transparent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,38 @@ fn round_trip() {
UserId(123)
);
}

#[test]
fn struct_with_reference() {
#[derive(FromSql, ToSql, Debug, PartialEq)]
#[postgres(transparent)]
struct UserName<'a>(&'a str);

let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();

let user_name = "tester";
let row = conn
.query_one("SELECT $1", &[&UserName(user_name)])
.unwrap();
let result: UserName<'_> = row.get(0);
assert_eq!(user_name, result.0);
}

#[test]
fn nested_struct_with_reference() {
#[derive(FromSql, ToSql, Debug, PartialEq)]
#[postgres(transparent)]
struct Inner<'a>(&'a str);

#[derive(FromSql, ToSql, Debug, PartialEq)]
#[postgres(transparent)]
struct UserName<'a>(#[postgres(borrow)] Inner<'a>);

let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();

let user_name = "tester";
let inner = Inner(user_name);
let row = conn.query_one("SELECT $1", &[&UserName(inner)]).unwrap();
let result: UserName<'_> = row.get(0);
assert_eq!(user_name, result.0 .0);
}
4 changes: 2 additions & 2 deletions postgres-derive/src/accepts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use quote::quote;
use std::iter;
use syn::Ident;

use crate::composites::Field;
use crate::composites::NamedField;
use crate::enums::Variant;

pub fn transparent_body(field: &syn::Field) -> TokenStream {
Expand Down Expand Up @@ -66,7 +66,7 @@ pub fn enum_body(name: &str, variants: &[Variant], allow_mismatch: bool) -> Toke
}
}

pub fn composite_body(name: &str, trait_: &str, fields: &[Field]) -> TokenStream {
pub fn composite_body(name: &str, trait_: &str, fields: &[NamedField]) -> TokenStream {
let num_fields = fields.len();
let trait_ = Ident::new(trait_, Span::call_site());
let traits = iter::repeat(&trait_);
Expand Down
21 changes: 13 additions & 8 deletions postgres-derive/src/composites.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,27 @@
use lifetimes::extract_borrowed_lifetimes;
use proc_macro2::Span;
use syn::{
punctuated::Punctuated, Error, GenericParam, Generics, Ident, Path, PathSegment, Type,
TypeParamBound,
punctuated::Punctuated, Error, GenericParam, Generics, Ident, Lifetime, Path, PathSegment,
Type, TypeParamBound,
};

use crate::{case::RenameRule, overrides::Overrides};
use crate::{case::RenameRule, lifetimes, overrides::Overrides};

pub struct Field {
pub struct NamedField {
pub name: String,
pub ident: Ident,
pub type_: Type,
pub borrowed_lifetimes: Vec<Lifetime>,
}

impl Field {
pub fn parse(raw: &syn::Field, rename_all: Option<RenameRule>) -> Result<Field, Error> {
impl NamedField {
pub fn parse(raw: &syn::Field, rename_all: Option<RenameRule>) -> Result<NamedField, Error> {
let overrides = Overrides::extract(&raw.attrs, false)?;
let ident = raw.ident.as_ref().unwrap().clone();

// field level name override takes precendence over container level rename_all override
let borrowed_lifetimes = extract_borrowed_lifetimes(raw, &overrides);

// field level name override takes precedence over container level rename_all override
let name = match overrides.name {
Some(n) => n,
None => {
Expand All @@ -31,10 +35,11 @@ impl Field {
}
};

Ok(Field {
Ok(NamedField {
name,
ident,
type_: raw.ty.clone(),
borrowed_lifetimes,
})
}
}
Expand Down
44 changes: 30 additions & 14 deletions postgres-derive/src/fromsql.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote};
use std::collections::BTreeSet;
use std::iter;
use std::iter::FromIterator;
use syn::{
punctuated::Punctuated, token, AngleBracketedGenericArguments, Data, DataStruct, DeriveInput,
Error, Fields, GenericArgument, GenericParam, Generics, Ident, Lifetime, PathArguments,
Expand All @@ -9,10 +11,11 @@ use syn::{
use syn::{LifetimeParam, TraitBound, TraitBoundModifier, TypeParamBound};

use crate::accepts;
use crate::composites::Field;
use crate::composites::NamedField;
use crate::composites::{append_generic_bound, new_derive_path};
use crate::enums::Variant;
use crate::overrides::Overrides;
use crate::transparent::UnnamedField;

pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
let overrides = Overrides::extract(&input.attrs, true)?;
Expand All @@ -29,16 +32,18 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
.clone()
.unwrap_or_else(|| input.ident.to_string());

let (accepts_body, to_sql_body) = if overrides.transparent {
let (accepts_body, to_sql_body, borrowed_lifetimes) = if overrides.transparent {
match input.data {
Data::Struct(DataStruct {
fields: Fields::Unnamed(ref fields),
..
}) if fields.unnamed.len() == 1 => {
let field = fields.unnamed.first().unwrap();
let parsed_field = UnnamedField::parse(field)?;
(
accepts::transparent_body(field),
transparent_body(&input.ident, field),
parsed_field.borrowed_lifetimes,
)
}
_ => {
Expand All @@ -59,6 +64,7 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
(
accepts::enum_body(&name, &variants, overrides.allow_mismatch),
enum_body(&input.ident, &variants),
vec![],
)
}
_ => {
Expand All @@ -79,16 +85,19 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
(
accepts::enum_body(&name, &variants, overrides.allow_mismatch),
enum_body(&input.ident, &variants),
vec![],
)
}
Data::Struct(DataStruct {
fields: Fields::Unnamed(ref fields),
..
}) if fields.unnamed.len() == 1 => {
let field = fields.unnamed.first().unwrap();
let parsed_field = UnnamedField::parse(field)?;
(
domain_accepts_body(&name, field),
domain_body(&input.ident, field),
parsed_field.borrowed_lifetimes,
)
}
Data::Struct(DataStruct {
Expand All @@ -98,11 +107,16 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
let fields = fields
.named
.iter()
.map(|field| Field::parse(field, overrides.rename_all))
.map(|field| NamedField::parse(field, overrides.rename_all))
.collect::<Result<Vec<_>, _>>()?;
let borrowed_lifetimes: Vec<_> = fields
.iter()
.flat_map(|f| f.borrowed_lifetimes.to_owned())
.collect();
(
accepts::composite_body(&name, "FromSql", &fields),
composite_body(&input.ident, &fields),
borrowed_lifetimes
)
}
_ => {
Expand All @@ -115,7 +129,7 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
};

let ident = &input.ident;
let (generics, lifetime) = build_generics(&input.generics);
let (generics, lifetime) = build_generics(&input.generics, borrowed_lifetimes);
let (impl_generics, _, _) = generics.split_for_impl();
let (_, ty_generics, where_clause) = input.generics.split_for_impl();
let out = quote! {
Expand Down Expand Up @@ -183,7 +197,7 @@ fn domain_body(ident: &Ident, field: &syn::Field) -> TokenStream {
}
}

fn composite_body(ident: &Ident, fields: &[Field]) -> TokenStream {
fn composite_body(ident: &Ident, fields: &[NamedField]) -> TokenStream {
let temp_vars = &fields
.iter()
.map(|f| format_ident!("__{}", f.ident))
Expand Down Expand Up @@ -233,16 +247,18 @@ fn composite_body(ident: &Ident, fields: &[Field]) -> TokenStream {
}
}

fn build_generics(source: &Generics) -> (Generics, Lifetime) {
// don't worry about lifetime name collisions, it doesn't make sense to derive FromSql on a struct with a lifetime
let lifetime = Lifetime::new("'a", Span::call_site());

fn build_generics(
source: &Generics,
borrowed_lifetimes: Vec<Lifetime>,
) -> (Generics, Lifetime) {
// This is the same parent lifetime name serde uses
let lifetime = Lifetime::new("'de", Span::call_site());
// Sort lifetimes for deterministic code-gen
JosephMoniz marked this conversation as resolved.
Show resolved Hide resolved
let sorted_lifetimes = BTreeSet::from_iter(borrowed_lifetimes);
let mut lifetime_param = LifetimeParam::new(lifetime.to_owned());
lifetime_param.bounds.extend(sorted_lifetimes);
let mut out = append_generic_bound(source.to_owned(), &new_fromsql_bound(&lifetime));
out.params.insert(
0,
GenericParam::Lifetime(LifetimeParam::new(lifetime.to_owned())),
);

out.params.insert(0, GenericParam::Lifetime(lifetime_param));
(out, lifetime)
}

Expand Down
2 changes: 2 additions & 0 deletions postgres-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ mod case;
mod composites;
mod enums;
mod fromsql;
mod lifetimes;
mod overrides;
mod tosql;
mod transparent;

#[proc_macro_derive(ToSql, attributes(postgres))]
pub fn derive_tosql(input: TokenStream) -> TokenStream {
Expand Down
37 changes: 37 additions & 0 deletions postgres-derive/src/lifetimes.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
use crate::overrides::Overrides;
use syn::{AngleBracketedGenericArguments, GenericArgument, Lifetime, PathArguments, Type};

pub(crate) fn extract_borrowed_lifetimes(raw: &syn::Field, overrides: &Overrides) -> Vec<Lifetime> {
let mut borrowed_lifetimes = vec![];

// If the field is a reference, it's lifetime should be implicitly borrowed. Serde does
// the same thing
if let Type::Reference(ref_type) = &raw.ty {
let lifetime = &ref_type.lifetime;
if !borrowed_lifetimes.contains(lifetime.as_ref().unwrap()) {
borrowed_lifetimes.push(lifetime.to_owned().unwrap());
}
}

// Borrow all generic lifetimes of fields marked with #[postgres(borrow)]
if overrides.borrows {
if let Type::Path(type_path) = &raw.ty {
for segment in &type_path.path.segments {
if let PathArguments::AngleBracketed(AngleBracketedGenericArguments {
args, ..
}) = &segment.arguments
{
for arg in args.iter() {
if let GenericArgument::Lifetime(lifetime) = arg {
if !borrowed_lifetimes.contains(lifetime) {
borrowed_lifetimes.push(lifetime.to_owned());
}
}
}
}
}
}
}

borrowed_lifetimes
}
10 changes: 10 additions & 0 deletions postgres-derive/src/overrides.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pub struct Overrides {
pub rename_all: Option<RenameRule>,
pub transparent: bool,
pub allow_mismatch: bool,
pub borrows: bool,
}

impl Overrides {
Expand All @@ -17,6 +18,7 @@ impl Overrides {
rename_all: None,
transparent: false,
allow_mismatch: false,
borrows: false,
};

for attr in attrs {
Expand Down Expand Up @@ -92,6 +94,14 @@ impl Overrides {
));
}
overrides.allow_mismatch = true;
} else if path.is_ident("borrow") {
if container_attr {
return Err(Error::new_spanned(
path,
"#[postgres(borrow)] is a field attribute",
));
}
overrides.borrows = true;
} else {
return Err(Error::new_spanned(path, "unknown override"));
}
Expand Down
6 changes: 3 additions & 3 deletions postgres-derive/src/tosql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use syn::{
};

use crate::accepts;
use crate::composites::Field;
use crate::composites::NamedField;
use crate::composites::{append_generic_bound, new_derive_path};
use crate::enums::Variant;
use crate::overrides::Overrides;
Expand Down Expand Up @@ -92,7 +92,7 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result<TokenStream, Error> {
let fields = fields
.named
.iter()
.map(|field| Field::parse(field, overrides.rename_all))
.map(|field| NamedField::parse(field, overrides.rename_all))
.collect::<Result<Vec<_>, _>>()?;
(
accepts::composite_body(&name, "ToSql", &fields),
Expand Down Expand Up @@ -168,7 +168,7 @@ fn domain_body() -> TokenStream {
}
}

fn composite_body(fields: &[Field]) -> TokenStream {
fn composite_body(fields: &[NamedField]) -> TokenStream {
let field_names = fields.iter().map(|f| &f.name);
let field_idents = fields.iter().map(|f| &f.ident);

Expand Down
16 changes: 16 additions & 0 deletions postgres-derive/src/transparent.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
use crate::lifetimes;
use crate::overrides::Overrides;
use lifetimes::extract_borrowed_lifetimes;
use syn::{Error, Lifetime};

pub struct UnnamedField {
pub borrowed_lifetimes: Vec<Lifetime>,
}

impl UnnamedField {
pub fn parse(raw: &syn::Field) -> Result<UnnamedField, Error> {
let overrides = Overrides::extract(&raw.attrs, false)?;
let borrowed_lifetimes = extract_borrowed_lifetimes(raw, &overrides);
Ok(UnnamedField { borrowed_lifetimes })
}
}