1
+ use std:: collections:: BTreeMap ;
1
2
use std:: path:: PathBuf ;
2
3
#[ cfg( feature = "offline" ) ]
3
4
use std:: sync:: { Arc , Mutex } ;
@@ -12,7 +13,7 @@ use quote::{format_ident, quote};
12
13
use sqlx_core:: connection:: Connection ;
13
14
use sqlx_core:: database:: Database ;
14
15
use sqlx_core:: { column:: Column , describe:: Describe , type_info:: TypeInfo } ;
15
- use sqlx_rt:: block_on;
16
+ use sqlx_rt:: { block_on, AsyncMutex } ;
16
17
17
18
use crate :: database:: DatabaseExt ;
18
19
use crate :: query:: data:: QueryData ;
@@ -117,6 +118,28 @@ static METADATA: Lazy<Metadata> = Lazy::new(|| {
117
118
118
119
pub fn expand_input ( input : QueryMacroInput ) -> crate :: Result < TokenStream > {
119
120
match & * METADATA {
121
+ #[ cfg( not( any(
122
+ feature = "postgres" ,
123
+ feature = "mysql" ,
124
+ feature = "mssql" ,
125
+ feature = "sqlite"
126
+ ) ) ) ]
127
+ Metadata {
128
+ offline : false ,
129
+ database_url : Some ( db_url) ,
130
+ ..
131
+ } => Err (
132
+ "At least one of the features ['postgres', 'mysql', 'mssql', 'sqlite'] must be enabled \
133
+ to get information directly from a database"
134
+ . into ( ) ,
135
+ ) ,
136
+
137
+ #[ cfg( any(
138
+ feature = "postgres" ,
139
+ feature = "mysql" ,
140
+ feature = "mssql" ,
141
+ feature = "sqlite"
142
+ ) ) ]
120
143
Metadata {
121
144
offline : false ,
122
145
database_url : Some ( db_url) ,
@@ -157,67 +180,76 @@ pub fn expand_input(input: QueryMacroInput) -> crate::Result<TokenStream> {
157
180
}
158
181
}
159
182
160
- #[ allow( unused_variables) ]
183
+ #[ cfg( any(
184
+ feature = "postgres" ,
185
+ feature = "mysql" ,
186
+ feature = "mssql" ,
187
+ feature = "sqlite"
188
+ ) ) ]
161
189
fn expand_from_db ( input : QueryMacroInput , db_url : & str ) -> crate :: Result < TokenStream > {
162
- // FIXME: Introduce [sqlx::any::AnyConnection] and [sqlx::any::AnyDatabase] to support
163
- // runtime determinism here
164
-
165
- let db_url = Url :: parse ( db_url) ?;
166
- match db_url. scheme ( ) {
167
- #[ cfg( feature = "postgres" ) ]
168
- "postgres" | "postgresql" => {
169
- let data = block_on ( async {
170
- let mut conn = sqlx_core:: postgres:: PgConnection :: connect ( db_url. as_str ( ) ) . await ?;
171
- QueryData :: from_db ( & mut conn, & input. sql ) . await
172
- } ) ?;
173
-
174
- expand_with_data ( input, data, false )
175
- } ,
176
-
177
- #[ cfg( not( feature = "postgres" ) ) ]
178
- "postgres" | "postgresql" => Err ( "database URL has the scheme of a PostgreSQL database but the `postgres` feature is not enabled" . into ( ) ) ,
179
-
180
- #[ cfg( feature = "mssql" ) ]
181
- "mssql" | "sqlserver" => {
182
- let data = block_on ( async {
183
- let mut conn = sqlx_core:: mssql:: MssqlConnection :: connect ( db_url. as_str ( ) ) . await ?;
184
- QueryData :: from_db ( & mut conn, & input. sql ) . await
185
- } ) ?;
186
-
187
- expand_with_data ( input, data, false )
188
- } ,
189
-
190
- #[ cfg( not( feature = "mssql" ) ) ]
191
- "mssql" | "sqlserver" => Err ( "database URL has the scheme of a MSSQL database but the `mssql` feature is not enabled" . into ( ) ) ,
192
-
193
- #[ cfg( feature = "mysql" ) ]
194
- "mysql" | "mariadb" => {
195
- let data = block_on ( async {
196
- let mut conn = sqlx_core:: mysql:: MySqlConnection :: connect ( db_url. as_str ( ) ) . await ?;
197
- QueryData :: from_db ( & mut conn, & input. sql ) . await
198
- } ) ?;
199
-
200
- expand_with_data ( input, data, false )
201
- } ,
202
-
203
- #[ cfg( not( feature = "mysql" ) ) ]
204
- "mysql" | "mariadb" => Err ( "database URL has the scheme of a MySQL/MariaDB database but the `mysql` feature is not enabled" . into ( ) ) ,
205
-
206
- #[ cfg( feature = "sqlite" ) ]
207
- "sqlite" => {
208
- let data = block_on ( async {
209
- let mut conn = sqlx_core:: sqlite:: SqliteConnection :: connect ( db_url. as_str ( ) ) . await ?;
210
- QueryData :: from_db ( & mut conn, & input. sql ) . await
211
- } ) ?;
190
+ use sqlx_core:: any:: { AnyConnection , AnyConnectionKind } ;
191
+
192
+ static CONNECTION_CACHE : Lazy < AsyncMutex < BTreeMap < String , AnyConnection > > > =
193
+ Lazy :: new ( || AsyncMutex :: new ( BTreeMap :: new ( ) ) ) ;
194
+
195
+ let maybe_expanded: crate :: Result < TokenStream > = block_on ( async {
196
+ let mut cache = CONNECTION_CACHE . lock ( ) . await ;
197
+
198
+ if !cache. contains_key ( db_url) {
199
+ let parsed_db_url = Url :: parse ( db_url) ?;
200
+
201
+ let conn = match parsed_db_url. scheme ( ) {
202
+ #[ cfg( feature = "sqlite" ) ]
203
+ "sqlite" => {
204
+ use sqlx_core:: connection:: ConnectOptions ;
205
+ use sqlx_core:: sqlite:: { SqliteConnectOptions , SqliteJournalMode } ;
206
+ use std:: str:: FromStr ;
207
+
208
+ let sqlite_conn = SqliteConnectOptions :: from_str ( db_url) ?
209
+ // Connections in `CONNECTION_CACHE` won't get dropped so disable journaling
210
+ // to avoid `.db-wal` and `.db-shm` files from lingering around
211
+ . journal_mode ( SqliteJournalMode :: Off )
212
+ . connect ( )
213
+ . await ?;
214
+ AnyConnection :: from ( sqlite_conn)
215
+ }
216
+ _ => AnyConnection :: connect ( db_url) . await ?,
217
+ } ;
212
218
213
- expand_with_data ( input , data , false )
214
- } ,
219
+ let _ = cache . insert ( db_url . to_owned ( ) , conn ) ;
220
+ }
215
221
216
- #[ cfg( not( feature = "sqlite" ) ) ]
217
- "sqlite" => Err ( "database URL has the scheme of a SQLite database but the `sqlite` feature is not enabled" . into ( ) ) ,
222
+ let conn_item = cache. get_mut ( db_url) . expect ( "Item was just inserted" ) ;
223
+ match conn_item. private_get_mut ( ) {
224
+ #[ cfg( feature = "postgres" ) ]
225
+ AnyConnectionKind :: Postgres ( conn) => {
226
+ let data = QueryData :: from_db ( conn, & input. sql ) . await ?;
227
+ expand_with_data ( input, data, false )
228
+ }
229
+ #[ cfg( feature = "mssql" ) ]
230
+ AnyConnectionKind :: Mssql ( conn) => {
231
+ let data = QueryData :: from_db ( conn, & input. sql ) . await ?;
232
+ expand_with_data ( input, data, false )
233
+ }
234
+ #[ cfg( feature = "mysql" ) ]
235
+ AnyConnectionKind :: MySql ( conn) => {
236
+ let data = QueryData :: from_db ( conn, & input. sql ) . await ?;
237
+ expand_with_data ( input, data, false )
238
+ }
239
+ #[ cfg( feature = "sqlite" ) ]
240
+ AnyConnectionKind :: Sqlite ( conn) => {
241
+ let data = QueryData :: from_db ( conn, & input. sql ) . await ?;
242
+ expand_with_data ( input, data, false )
243
+ }
244
+ // Variants depend on feature flags
245
+ #[ allow( unreachable_patterns) ]
246
+ item => {
247
+ return Err ( format ! ( "Missing expansion needed for: {:?}" , item) . into ( ) ) ;
248
+ }
249
+ }
250
+ } ) ;
218
251
219
- scheme => Err ( format ! ( "unknown database URL scheme {:?}" , scheme) . into ( ) )
220
- }
252
+ maybe_expanded. map_err ( Into :: into)
221
253
}
222
254
223
255
#[ cfg( feature = "offline" ) ]
0 commit comments