Skip to content

Commit a2691b9

Browse files
Reuse a cached DB connection instead of always recreating for sqlx-macros (#1782)
* refactor: Reuse a cached connection instead of always recreating for `sqlx-macros` * fix: Fix type inference issue when no database features used * refactor: Switch cached db conn to an `AnyConnection` * fix: Fix invalid variant name only exposed with features * fix: Tweak connection options for SQLite with `sqlx-macros` * fix: Remove read only option for SQLite connection * fix: Fix feature flags regarding usage of `sqlx_core::any`
1 parent fa5c436 commit a2691b9

File tree

6 files changed

+124
-65
lines changed

6 files changed

+124
-65
lines changed

sqlx-core/src/any/connection/mod.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@ mod executor;
3333
pub struct AnyConnection(pub(super) AnyConnectionKind);
3434

3535
#[derive(Debug)]
36-
pub(crate) enum AnyConnectionKind {
36+
// Used internally in `sqlx-macros`
37+
#[doc(hidden)]
38+
pub enum AnyConnectionKind {
3739
#[cfg(feature = "postgres")]
3840
Postgres(postgres::PgConnection),
3941

@@ -69,6 +71,12 @@ impl AnyConnection {
6971
pub fn kind(&self) -> AnyKind {
7072
self.0.kind()
7173
}
74+
75+
// Used internally in `sqlx-macros`
76+
#[doc(hidden)]
77+
pub fn private_get_mut(&mut self) -> &mut AnyConnectionKind {
78+
&mut self.0
79+
}
7280
}
7381

7482
macro_rules! delegate_to {

sqlx-core/src/any/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ mod migrate;
3131
pub use arguments::{AnyArgumentBuffer, AnyArguments};
3232
pub use column::{AnyColumn, AnyColumnIndex};
3333
pub use connection::AnyConnection;
34+
// Used internally in `sqlx-macros`
35+
#[doc(hidden)]
36+
pub use connection::AnyConnectionKind;
3437
pub use database::Any;
3538
pub use decode::AnyDecode;
3639
pub use encode::AnyEncode;

sqlx-core/src/pool/mod.rs

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,15 @@
5555
//! [`Pool::begin`].
5656
5757
use self::inner::SharedPool;
58-
#[cfg(feature = "any")]
58+
#[cfg(all(
59+
any(
60+
feature = "postgres",
61+
feature = "mysql",
62+
feature = "mssql",
63+
feature = "sqlite"
64+
),
65+
feature = "any"
66+
))]
5967
use crate::any::{Any, AnyKind};
6068
use crate::connection::Connection;
6169
use crate::database::Database;
@@ -429,12 +437,19 @@ impl<DB: Database> Pool<DB> {
429437
}
430438
}
431439

432-
#[cfg(feature = "any")]
440+
#[cfg(all(
441+
any(
442+
feature = "postgres",
443+
feature = "mysql",
444+
feature = "mssql",
445+
feature = "sqlite"
446+
),
447+
feature = "any"
448+
))]
433449
impl Pool<Any> {
434450
/// Returns the database driver currently in-use by this `Pool`.
435451
///
436452
/// Determined by the connection URI.
437-
#[cfg(feature = "any")]
438453
pub fn any_kind(&self) -> AnyKind {
439454
self.0.connect_options.kind()
440455
}

sqlx-macros/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ heck = { version = "0.4", features = ["unicode"] }
8484
either = "1.6.1"
8585
once_cell = "1.9.0"
8686
proc-macro2 = { version = "1.0.36", default-features = false }
87-
sqlx-core = { version = "0.5.12", default-features = false, path = "../sqlx-core" }
87+
sqlx-core = { version = "0.5.12", default-features = false, features = ["any"], path = "../sqlx-core" }
8888
sqlx-rt = { version = "0.5.12", default-features = false, path = "../sqlx-rt" }
8989
serde = { version = "1.0.132", features = ["derive"], optional = true }
9090
serde_json = { version = "1.0.73", optional = true }

sqlx-macros/src/query/mod.rs

Lines changed: 90 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use std::collections::BTreeMap;
12
use std::path::PathBuf;
23
#[cfg(feature = "offline")]
34
use std::sync::{Arc, Mutex};
@@ -12,7 +13,7 @@ use quote::{format_ident, quote};
1213
use sqlx_core::connection::Connection;
1314
use sqlx_core::database::Database;
1415
use sqlx_core::{column::Column, describe::Describe, type_info::TypeInfo};
15-
use sqlx_rt::block_on;
16+
use sqlx_rt::{block_on, AsyncMutex};
1617

1718
use crate::database::DatabaseExt;
1819
use crate::query::data::QueryData;
@@ -117,6 +118,28 @@ static METADATA: Lazy<Metadata> = Lazy::new(|| {
117118

118119
pub fn expand_input(input: QueryMacroInput) -> crate::Result<TokenStream> {
119120
match &*METADATA {
121+
#[cfg(not(any(
122+
feature = "postgres",
123+
feature = "mysql",
124+
feature = "mssql",
125+
feature = "sqlite"
126+
)))]
127+
Metadata {
128+
offline: false,
129+
database_url: Some(db_url),
130+
..
131+
} => Err(
132+
"At least one of the features ['postgres', 'mysql', 'mssql', 'sqlite'] must be enabled \
133+
to get information directly from a database"
134+
.into(),
135+
),
136+
137+
#[cfg(any(
138+
feature = "postgres",
139+
feature = "mysql",
140+
feature = "mssql",
141+
feature = "sqlite"
142+
))]
120143
Metadata {
121144
offline: false,
122145
database_url: Some(db_url),
@@ -157,67 +180,76 @@ pub fn expand_input(input: QueryMacroInput) -> crate::Result<TokenStream> {
157180
}
158181
}
159182

160-
#[allow(unused_variables)]
183+
#[cfg(any(
184+
feature = "postgres",
185+
feature = "mysql",
186+
feature = "mssql",
187+
feature = "sqlite"
188+
))]
161189
fn expand_from_db(input: QueryMacroInput, db_url: &str) -> crate::Result<TokenStream> {
162-
// FIXME: Introduce [sqlx::any::AnyConnection] and [sqlx::any::AnyDatabase] to support
163-
// runtime determinism here
164-
165-
let db_url = Url::parse(db_url)?;
166-
match db_url.scheme() {
167-
#[cfg(feature = "postgres")]
168-
"postgres" | "postgresql" => {
169-
let data = block_on(async {
170-
let mut conn = sqlx_core::postgres::PgConnection::connect(db_url.as_str()).await?;
171-
QueryData::from_db(&mut conn, &input.sql).await
172-
})?;
173-
174-
expand_with_data(input, data, false)
175-
},
176-
177-
#[cfg(not(feature = "postgres"))]
178-
"postgres" | "postgresql" => Err("database URL has the scheme of a PostgreSQL database but the `postgres` feature is not enabled".into()),
179-
180-
#[cfg(feature = "mssql")]
181-
"mssql" | "sqlserver" => {
182-
let data = block_on(async {
183-
let mut conn = sqlx_core::mssql::MssqlConnection::connect(db_url.as_str()).await?;
184-
QueryData::from_db(&mut conn, &input.sql).await
185-
})?;
186-
187-
expand_with_data(input, data, false)
188-
},
189-
190-
#[cfg(not(feature = "mssql"))]
191-
"mssql" | "sqlserver" => Err("database URL has the scheme of a MSSQL database but the `mssql` feature is not enabled".into()),
192-
193-
#[cfg(feature = "mysql")]
194-
"mysql" | "mariadb" => {
195-
let data = block_on(async {
196-
let mut conn = sqlx_core::mysql::MySqlConnection::connect(db_url.as_str()).await?;
197-
QueryData::from_db(&mut conn, &input.sql).await
198-
})?;
199-
200-
expand_with_data(input, data, false)
201-
},
202-
203-
#[cfg(not(feature = "mysql"))]
204-
"mysql" | "mariadb" => Err("database URL has the scheme of a MySQL/MariaDB database but the `mysql` feature is not enabled".into()),
205-
206-
#[cfg(feature = "sqlite")]
207-
"sqlite" => {
208-
let data = block_on(async {
209-
let mut conn = sqlx_core::sqlite::SqliteConnection::connect(db_url.as_str()).await?;
210-
QueryData::from_db(&mut conn, &input.sql).await
211-
})?;
190+
use sqlx_core::any::{AnyConnection, AnyConnectionKind};
191+
192+
static CONNECTION_CACHE: Lazy<AsyncMutex<BTreeMap<String, AnyConnection>>> =
193+
Lazy::new(|| AsyncMutex::new(BTreeMap::new()));
194+
195+
let maybe_expanded: crate::Result<TokenStream> = block_on(async {
196+
let mut cache = CONNECTION_CACHE.lock().await;
197+
198+
if !cache.contains_key(db_url) {
199+
let parsed_db_url = Url::parse(db_url)?;
200+
201+
let conn = match parsed_db_url.scheme() {
202+
#[cfg(feature = "sqlite")]
203+
"sqlite" => {
204+
use sqlx_core::connection::ConnectOptions;
205+
use sqlx_core::sqlite::{SqliteConnectOptions, SqliteJournalMode};
206+
use std::str::FromStr;
207+
208+
let sqlite_conn = SqliteConnectOptions::from_str(db_url)?
209+
// Connections in `CONNECTION_CACHE` won't get dropped so disable journaling
210+
// to avoid `.db-wal` and `.db-shm` files from lingering around
211+
.journal_mode(SqliteJournalMode::Off)
212+
.connect()
213+
.await?;
214+
AnyConnection::from(sqlite_conn)
215+
}
216+
_ => AnyConnection::connect(db_url).await?,
217+
};
212218

213-
expand_with_data(input, data, false)
214-
},
219+
let _ = cache.insert(db_url.to_owned(), conn);
220+
}
215221

216-
#[cfg(not(feature = "sqlite"))]
217-
"sqlite" => Err("database URL has the scheme of a SQLite database but the `sqlite` feature is not enabled".into()),
222+
let conn_item = cache.get_mut(db_url).expect("Item was just inserted");
223+
match conn_item.private_get_mut() {
224+
#[cfg(feature = "postgres")]
225+
AnyConnectionKind::Postgres(conn) => {
226+
let data = QueryData::from_db(conn, &input.sql).await?;
227+
expand_with_data(input, data, false)
228+
}
229+
#[cfg(feature = "mssql")]
230+
AnyConnectionKind::Mssql(conn) => {
231+
let data = QueryData::from_db(conn, &input.sql).await?;
232+
expand_with_data(input, data, false)
233+
}
234+
#[cfg(feature = "mysql")]
235+
AnyConnectionKind::MySql(conn) => {
236+
let data = QueryData::from_db(conn, &input.sql).await?;
237+
expand_with_data(input, data, false)
238+
}
239+
#[cfg(feature = "sqlite")]
240+
AnyConnectionKind::Sqlite(conn) => {
241+
let data = QueryData::from_db(conn, &input.sql).await?;
242+
expand_with_data(input, data, false)
243+
}
244+
// Variants depend on feature flags
245+
#[allow(unreachable_patterns)]
246+
item => {
247+
return Err(format!("Missing expansion needed for: {:?}", item).into());
248+
}
249+
}
250+
});
218251

219-
scheme => Err(format!("unknown database URL scheme {:?}", scheme).into())
220-
}
252+
maybe_expanded.map_err(Into::into)
221253
}
222254

223255
#[cfg(feature = "offline")]

sqlx-rt/src/lib.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ pub use native_tls;
3737
))]
3838
pub use tokio::{
3939
self, fs, io::AsyncRead, io::AsyncReadExt, io::AsyncWrite, io::AsyncWriteExt, io::ReadBuf,
40-
net::TcpStream, runtime::Handle, task::spawn, task::yield_now, time::sleep, time::timeout,
40+
net::TcpStream, runtime::Handle, sync::Mutex as AsyncMutex, task::spawn, task::yield_now,
41+
time::sleep, time::timeout,
4142
};
4243

4344
#[cfg(all(
@@ -142,7 +143,7 @@ macro_rules! blocking {
142143
pub use async_std::{
143144
self, fs, future::timeout, io::prelude::ReadExt as AsyncReadExt,
144145
io::prelude::WriteExt as AsyncWriteExt, io::Read as AsyncRead, io::Write as AsyncWrite,
145-
net::TcpStream, task::sleep, task::spawn, task::yield_now,
146+
net::TcpStream, sync::Mutex as AsyncMutex, task::sleep, task::spawn, task::yield_now,
146147
};
147148

148149
#[cfg(all(

0 commit comments

Comments
 (0)