Skip to content

Commit e89b715

Browse files
Allow pipelining with composed futures for Postgres
1 parent c73ded6 commit e89b715

File tree

1 file changed

+166
-1
lines changed

1 file changed

+166
-1
lines changed

src/pg/mod.rs

Lines changed: 166 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,48 @@ const FAKE_OID: u32 = 0;
114114
/// # }
115115
/// ```
116116
///
117+
/// For more complex cases, an immutable reference to the connection need to be used:
118+
/// ```rust
119+
/// # include!("../doctest_setup.rs");
120+
/// use diesel_async::RunQueryDsl;
121+
///
122+
/// #
123+
/// # #[tokio::main(flavor = "current_thread")]
124+
/// # async fn main() {
125+
/// # run_test().await.unwrap();
126+
/// # }
127+
/// #
128+
/// # async fn run_test() -> QueryResult<()> {
129+
/// # use diesel::sql_types::{Text, Integer};
130+
/// # let conn = &mut establish_connection().await;
131+
/// #
132+
/// async fn fn12(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> {
133+
/// let f1 = diesel::select(1_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);
134+
/// let f2 = diesel::select(2_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);
135+
///
136+
/// futures_util::try_join!(f1, f2)
137+
/// }
138+
///
139+
/// async fn fn34(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> {
140+
/// let f3 = diesel::select(3_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);
141+
/// let f4 = diesel::select(4_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);
142+
///
143+
/// futures_util::try_join!(f3, f4)
144+
/// }
145+
///
146+
/// let f12 = fn12(&conn);
147+
/// let f34 = fn34(&conn);
148+
///
149+
/// let ((r1, r2), (r3, r4)) = futures_util::try_join!(f12, f34).unwrap();
150+
///
151+
/// assert_eq!(r1, 1);
152+
/// assert_eq!(r2, 2);
153+
/// assert_eq!(r3, 3);
154+
/// assert_eq!(r4, 4);
155+
/// # Ok(())
156+
/// # }
157+
/// ```
158+
///
117159
/// ## TLS
118160
///
119161
/// Connections created by [`AsyncPgConnection::establish`] do not support TLS.
@@ -136,6 +178,12 @@ pub struct AsyncPgConnection {
136178
}
137179

138180
impl SimpleAsyncConnection for AsyncPgConnection {
181+
async fn batch_execute(&mut self, query: &str) -> QueryResult<()> {
182+
SimpleAsyncConnection::batch_execute(&mut &*self, query).await
183+
}
184+
}
185+
186+
impl SimpleAsyncConnection for &AsyncPgConnection {
139187
async fn batch_execute(&mut self, query: &str) -> QueryResult<()> {
140188
self.record_instrumentation(InstrumentationEvent::start_query(&StrQueryHelper::new(
141189
query,
@@ -167,6 +215,38 @@ impl AsyncConnectionCore for AsyncPgConnection {
167215
type Row<'conn, 'query> = PgRow;
168216
type Backend = diesel::pg::Pg;
169217

218+
fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
219+
where
220+
T: AsQuery + 'query,
221+
T::Query: QueryFragment<Self::Backend> + QueryId + 'query,
222+
{
223+
AsyncConnectionCore::load(&mut &*self, source)
224+
}
225+
226+
fn execute_returning_count<'conn, 'query, T>(
227+
&'conn mut self,
228+
source: T,
229+
) -> Self::ExecuteFuture<'conn, 'query>
230+
where
231+
T: QueryFragment<Self::Backend> + QueryId + 'query,
232+
{
233+
AsyncConnectionCore::execute_returning_count(&mut &*self, source)
234+
}
235+
}
236+
237+
impl AsyncConnectionCore for &AsyncPgConnection {
238+
type LoadFuture<'conn, 'query> =
239+
<AsyncPgConnection as AsyncConnectionCore>::LoadFuture<'conn, 'query>;
240+
241+
type ExecuteFuture<'conn, 'query> =
242+
<AsyncPgConnection as AsyncConnectionCore>::ExecuteFuture<'conn, 'query>;
243+
244+
type Stream<'conn, 'query> = <AsyncPgConnection as AsyncConnectionCore>::Stream<'conn, 'query>;
245+
246+
type Row<'conn, 'query> = <AsyncPgConnection as AsyncConnectionCore>::Row<'conn, 'query>;
247+
248+
type Backend = <AsyncPgConnection as AsyncConnectionCore>::Backend;
249+
170250
fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
171251
where
172252
T: AsQuery + 'query,
@@ -942,11 +1022,15 @@ mod tests {
9421022
use crate::run_query_dsl::RunQueryDsl;
9431023
use diesel::sql_types::Integer;
9441024
use diesel::IntoSql;
1025+
use futures_util::future::try_join;
1026+
use futures_util::try_join;
1027+
use scoped_futures::ScopedFutureExt;
9451028

9461029
#[tokio::test]
9471030
async fn pipelining() {
9481031
let database_url =
9491032
std::env::var("DATABASE_URL").expect("DATABASE_URL must be set in order to run tests");
1033+
9501034
let mut conn = crate::AsyncPgConnection::establish(&database_url)
9511035
.await
9521036
.unwrap();
@@ -957,9 +1041,90 @@ mod tests {
9571041
let f1 = q1.get_result::<i32>(&mut conn);
9581042
let f2 = q2.get_result::<i32>(&mut conn);
9591043

960-
let (r1, r2) = futures_util::try_join!(f1, f2).unwrap();
1044+
let (r1, r2) = try_join!(f1, f2).unwrap();
1045+
1046+
assert_eq!(r1, 1);
1047+
assert_eq!(r2, 2);
1048+
}
1049+
1050+
#[tokio::test]
1051+
async fn pipelining_with_composed_futures() {
1052+
let database_url =
1053+
std::env::var("DATABASE_URL").expect("DATABASE_URL must be set in order to run tests");
1054+
1055+
let conn = crate::AsyncPgConnection::establish(&database_url)
1056+
.await
1057+
.unwrap();
1058+
1059+
async fn fn12(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> {
1060+
let f1 = diesel::select(1_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);
1061+
let f2 = diesel::select(2_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);
1062+
1063+
try_join!(f1, f2)
1064+
}
1065+
1066+
async fn fn34(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> {
1067+
let f3 = diesel::select(3_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);
1068+
let f4 = diesel::select(4_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);
1069+
1070+
try_join!(f3, f4)
1071+
}
1072+
1073+
let f12 = fn12(&conn);
1074+
let f34 = fn34(&conn);
1075+
1076+
let ((r1, r2), (r3, r4)) = try_join!(f12, f34).unwrap();
9611077

9621078
assert_eq!(r1, 1);
9631079
assert_eq!(r2, 2);
1080+
assert_eq!(r3, 3);
1081+
assert_eq!(r4, 4);
1082+
}
1083+
1084+
#[tokio::test]
1085+
async fn pipelining_with_composed_futures_and_transaction() {
1086+
let database_url =
1087+
std::env::var("DATABASE_URL").expect("DATABASE_URL must be set in order to run tests");
1088+
1089+
let mut conn = crate::AsyncPgConnection::establish(&database_url)
1090+
.await
1091+
.unwrap();
1092+
1093+
fn erase<'a, T: Future + Send + 'a>(t: T) -> impl Future<Output = T::Output> + Send + 'a {
1094+
t
1095+
}
1096+
1097+
async fn fn12(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> {
1098+
let f1 = diesel::select(1_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);
1099+
let f2 = diesel::select(2_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);
1100+
1101+
erase(try_join(f1, f2)).await
1102+
}
1103+
1104+
async fn fn34(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> {
1105+
let f3 = diesel::select(3_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);
1106+
let f4 = diesel::select(4_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);
1107+
1108+
erase(try_join(f3, f4)).await
1109+
}
1110+
1111+
conn.transaction(|conn| {
1112+
async move {
1113+
let f12 = fn12(conn);
1114+
let f34 = fn34(conn);
1115+
1116+
let ((r1, r2), (r3, r4)) = try_join!(f12, f34).unwrap();
1117+
1118+
assert_eq!(r1, 1);
1119+
assert_eq!(r2, 2);
1120+
assert_eq!(r3, 3);
1121+
assert_eq!(r4, 4);
1122+
1123+
QueryResult::<_>::Ok(())
1124+
}
1125+
.scope_boxed()
1126+
})
1127+
.await
1128+
.unwrap();
9641129
}
9651130
}

0 commit comments

Comments
 (0)