Skip to content

Commit 7cdb68b

Browse files
authored
support flatten attribute in FromRow macro (#1959)
* support flatten attribute in FromRow macro * added docs for flatten FromRow attribute
1 parent bc3e705 commit 7cdb68b

File tree

4 files changed

+117
-36
lines changed

4 files changed

+117
-36
lines changed

sqlx-core/src/from_row.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,36 @@ use crate::row::Row;
9292
/// will set the value of the field `location` to the default value of `Option<String>`,
9393
/// which is `None`.
9494
///
95+
/// ### `flatten`
96+
///
97+
/// If you want to handle a field that implements [`FromRow`],
98+
/// you can use the `flatten` attribute to specify that you want
99+
/// it to use [`FromRow`] for parsing rather than the usual method.
100+
/// For example:
101+
///
102+
/// ```rust,ignore
103+
/// #[derive(sqlx::FromRow)]
104+
/// struct Address {
105+
/// country: String,
106+
/// city: String,
107+
/// road: String,
108+
/// }
109+
///
110+
/// #[derive(sqlx::FromRow)]
111+
/// struct User {
112+
/// id: i32,
113+
/// name: String,
114+
/// #[sqlx(flatten)]
115+
/// address: Address,
116+
/// }
117+
/// ```
118+
/// Given a query such as:
119+
///
120+
/// ```sql
121+
/// SELECT id, name, country, city, road FROM users;
122+
/// ```
123+
///
124+
/// This field is compatible with the `default` attribute.
95125
pub trait FromRow<'r, R: Row>: Sized {
96126
fn from_row(row: &'r R) -> Result<Self, Error>;
97127
}

sqlx-macros/src/derives/attributes.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ pub struct SqlxContainerAttributes {
7070
pub struct SqlxChildAttributes {
7171
pub rename: Option<String>,
7272
pub default: bool,
73+
pub flatten: bool,
7374
}
7475

7576
pub fn parse_container_attributes(input: &[Attribute]) -> syn::Result<SqlxContainerAttributes> {
@@ -177,6 +178,7 @@ pub fn parse_container_attributes(input: &[Attribute]) -> syn::Result<SqlxContai
177178
pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result<SqlxChildAttributes> {
178179
let mut rename = None;
179180
let mut default = false;
181+
let mut flatten = false;
180182

181183
for attr in input.iter().filter(|a| a.path.is_ident("sqlx")) {
182184
let meta = attr
@@ -193,6 +195,7 @@ pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result<SqlxChildAttri
193195
..
194196
}) if path.is_ident("rename") => try_set!(rename, val.value(), value),
195197
Meta::Path(path) if path.is_ident("default") => default = true,
198+
Meta::Path(path) if path.is_ident("flatten") => flatten = true,
196199
u => fail!(u, "unexpected attribute"),
197200
},
198201
u => fail!(u, "unexpected attribute"),
@@ -201,7 +204,11 @@ pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result<SqlxChildAttri
201204
}
202205
}
203206

204-
Ok(SqlxChildAttributes { rename, default })
207+
Ok(SqlxChildAttributes {
208+
rename,
209+
default,
210+
flatten,
211+
})
205212
}
206213

207214
pub fn check_transparent_attributes(

sqlx-macros/src/derives/row.rs

Lines changed: 38 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use proc_macro2::{Span, TokenStream};
22
use quote::quote;
33
use syn::{
4-
parse_quote, punctuated::Punctuated, token::Comma, Data, DataStruct, DeriveInput, Field,
4+
parse_quote, punctuated::Punctuated, token::Comma, Data, DataStruct, DeriveInput, Expr, Field,
55
Fields, FieldsNamed, FieldsUnnamed, Lifetime, Stmt,
66
};
77

@@ -63,46 +63,49 @@ fn expand_derive_from_row_struct(
6363

6464
predicates.push(parse_quote!(&#lifetime ::std::primitive::str: ::sqlx::ColumnIndex<R>));
6565

66-
for field in fields {
67-
let ty = &field.ty;
68-
69-
predicates.push(parse_quote!(#ty: ::sqlx::decode::Decode<#lifetime, R::Database>));
70-
predicates.push(parse_quote!(#ty: ::sqlx::types::Type<R::Database>));
71-
}
72-
73-
let (impl_generics, _, where_clause) = generics.split_for_impl();
74-
7566
let container_attributes = parse_container_attributes(&input.attrs)?;
7667

77-
let reads = fields.iter().filter_map(|field| -> Option<Stmt> {
78-
let id = &field.ident.as_ref()?;
79-
let attributes = parse_child_attributes(&field.attrs).unwrap();
80-
let id_s = attributes
81-
.rename
82-
.or_else(|| Some(id.to_string().trim_start_matches("r#").to_owned()))
83-
.map(|s| match container_attributes.rename_all {
84-
Some(pattern) => rename_all(&s, pattern),
85-
None => s,
86-
})
87-
.unwrap();
88-
89-
let ty = &field.ty;
90-
91-
if attributes.default {
92-
Some(
93-
parse_quote!(let #id: #ty = row.try_get(#id_s).or_else(|e| match e {
68+
let reads: Vec<Stmt> = fields
69+
.iter()
70+
.filter_map(|field| -> Option<Stmt> {
71+
let id = &field.ident.as_ref()?;
72+
let attributes = parse_child_attributes(&field.attrs).unwrap();
73+
let ty = &field.ty;
74+
75+
let expr: Expr = if attributes.flatten {
76+
predicates.push(parse_quote!(#ty: ::sqlx::FromRow<#lifetime, R>));
77+
parse_quote!(#ty::from_row(row))
78+
} else {
79+
predicates.push(parse_quote!(#ty: ::sqlx::decode::Decode<#lifetime, R::Database>));
80+
predicates.push(parse_quote!(#ty: ::sqlx::types::Type<R::Database>));
81+
82+
let id_s = attributes
83+
.rename
84+
.or_else(|| Some(id.to_string().trim_start_matches("r#").to_owned()))
85+
.map(|s| match container_attributes.rename_all {
86+
Some(pattern) => rename_all(&s, pattern),
87+
None => s,
88+
})
89+
.unwrap();
90+
parse_quote!(row.try_get(#id_s))
91+
};
92+
93+
if attributes.default {
94+
Some(parse_quote!(let #id: #ty = #expr.or_else(|e| match e {
9495
::sqlx::Error::ColumnNotFound(_) => {
9596
::std::result::Result::Ok(Default::default())
9697
},
9798
e => ::std::result::Result::Err(e)
98-
})?;),
99-
)
100-
} else {
101-
Some(parse_quote!(
102-
let #id: #ty = row.try_get(#id_s)?;
103-
))
104-
}
105-
});
99+
})?;))
100+
} else {
101+
Some(parse_quote!(
102+
let #id: #ty = #expr?;
103+
))
104+
}
105+
})
106+
.collect();
107+
108+
let (impl_generics, _, where_clause) = generics.split_for_impl();
106109

107110
let names = fields.iter().map(|field| &field.ident);
108111

tests/postgres/derives.rs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -573,3 +573,44 @@ async fn test_default() -> anyhow::Result<()> {
573573

574574
Ok(())
575575
}
576+
577+
#[cfg(feature = "macros")]
578+
#[sqlx_macros::test]
579+
async fn test_flatten() -> anyhow::Result<()> {
580+
#[derive(Debug, Default, sqlx::FromRow)]
581+
struct AccountDefault {
582+
default: Option<i32>,
583+
}
584+
585+
#[derive(Debug, sqlx::FromRow)]
586+
struct UserInfo {
587+
name: String,
588+
surname: String,
589+
}
590+
591+
#[derive(Debug, sqlx::FromRow)]
592+
struct AccountKeyword {
593+
id: i32,
594+
#[sqlx(flatten)]
595+
info: UserInfo,
596+
#[sqlx(default)]
597+
#[sqlx(flatten)]
598+
default: AccountDefault,
599+
}
600+
601+
let mut conn = new::<Postgres>().await?;
602+
603+
let account: AccountKeyword = sqlx::query_as(
604+
r#"SELECT * from (VALUES (1, 'foo', 'bar')) accounts("id", "name", "surname")"#,
605+
)
606+
.fetch_one(&mut conn)
607+
.await?;
608+
println!("{:?}", account);
609+
610+
assert_eq!(1, account.id);
611+
assert_eq!("foo", account.info.name);
612+
assert_eq!("bar", account.info.surname);
613+
assert_eq!(None, account.default.default);
614+
615+
Ok(())
616+
}

0 commit comments

Comments
 (0)