11mod engines;
22
3- use std:: collections:: { BTreeMap , BTreeSet } ;
3+ use std:: collections:: { BTreeMap , BTreeSet , HashSet } ;
44use std:: io:: { stdout, Read , Seek , SeekFrom , Write } ;
55use std:: path:: { Path , PathBuf } ;
66use std:: time:: { Duration , Instant } ;
@@ -17,9 +17,10 @@ use quick_junit::{NonSuccessKind, Report, TestCase, TestCaseStatus, TestSuite};
1717use rand:: distributions:: DistString ;
1818use rand:: seq:: SliceRandom ;
1919use sqllogictest:: {
20- default_validator , strict_column_validator , update_record_with_output , AsyncDB , Injected ,
21- MakeConnection , Record , Runner ,
20+ default_column_validator , default_normalizer , default_validator , update_record_with_output ,
21+ AsyncDB , Injected , MakeConnection , Record , Runner ,
2222} ;
23+ use tokio_util:: task:: AbortOnDropHandle ;
2324
2425#[ derive( Default , Copy , Clone , Debug , PartialEq , Eq , ValueEnum ) ]
2526#[ must_use]
@@ -62,6 +63,13 @@ struct Opt {
6263 /// database will be created for each test file.
6364 #[ clap( long, short) ]
6465 jobs : Option < usize > ,
66+ /// When using `-j`, whether to keep the temporary database when a test case fails.
67+ #[ clap( long, default_value = "false" , env = "SLT_KEEP_DB_ON_FAILURE" ) ]
68+ keep_db_on_failure : bool ,
69+
70+ /// Whether to exit immediately when a test case fails.
71+ #[ clap( long, default_value = "false" , env = "SLT_FAIL_FAST" ) ]
72+ fail_fast : bool ,
6573
6674 /// Report to junit XML.
6775 #[ clap( long) ]
@@ -146,6 +154,8 @@ pub async fn main() -> Result<()> {
146154 external_engine_command_template,
147155 color,
148156 jobs,
157+ keep_db_on_failure,
158+ fail_fast,
149159 junit,
150160 host,
151161 port,
@@ -228,12 +238,14 @@ pub async fn main() -> Result<()> {
228238 let result = if let Some ( jobs) = jobs {
229239 run_parallel (
230240 jobs,
241+ keep_db_on_failure,
231242 & mut test_suite,
232243 files,
233244 & engine,
234245 config,
235246 & labels,
236247 junit. clone ( ) ,
248+ fail_fast,
237249 )
238250 . await
239251 } else {
@@ -244,6 +256,7 @@ pub async fn main() -> Result<()> {
244256 config,
245257 & labels,
246258 junit. clone ( ) ,
259+ fail_fast,
247260 )
248261 . await
249262 } ;
@@ -257,14 +270,17 @@ pub async fn main() -> Result<()> {
257270 result
258271}
259272
273+ #[ allow( clippy:: too_many_arguments) ]
260274async fn run_parallel (
261275 jobs : usize ,
276+ keep_db_on_failure : bool ,
262277 test_suite : & mut TestSuite ,
263278 files : Vec < PathBuf > ,
264279 engine : & EngineConfig ,
265280 config : DBConfig ,
266281 labels : & [ String ] ,
267282 junit : Option < String > ,
283+ fail_fast : bool ,
268284) -> Result < ( ) > {
269285 let mut create_databases = BTreeMap :: new ( ) ;
270286 let mut filenames = BTreeSet :: new ( ) ;
@@ -299,36 +315,40 @@ async fn run_parallel(
299315 }
300316 }
301317
302- let mut stream = futures:: stream:: iter ( create_databases. into_iter ( ) )
318+ let mut stream = futures:: stream:: iter ( create_databases)
303319 . map ( |( db_name, filename) | {
304320 let mut config = config. clone ( ) ;
305- config. db = db_name;
321+ config. db . clone_from ( & db_name) ;
306322 let file = filename. to_string_lossy ( ) . to_string ( ) ;
307323 let engine = engine. clone ( ) ;
308324 let labels = labels. to_vec ( ) ;
309325 async move {
310- let ( buf, res) = tokio:: spawn ( async move {
326+ let ( buf, res) = AbortOnDropHandle :: new ( tokio:: spawn ( async move {
311327 let mut buf = vec ! [ ] ;
312328 let res =
313329 connect_and_run_test_file ( & mut buf, filename, & engine, config, & labels)
314330 . await ;
315331 ( buf, res)
316- } )
332+ } ) )
317333 . await
318334 . unwrap ( ) ;
319- ( file, res, buf)
335+ ( db_name , file, res, buf)
320336 }
321337 } )
322338 . buffer_unordered ( jobs) ;
323339
324340 eprintln ! ( "{}" , style( "[TEST IN PROGRESS]" ) . blue( ) . bold( ) ) ;
325341
326342 let mut failed_case = vec ! [ ] ;
343+ let mut failed_db: HashSet < String > = HashSet :: new ( ) ;
344+ let mut remaining_files: HashSet < String > = HashSet :: from_iter ( filenames. clone ( ) ) ;
327345
328346 let start = Instant :: now ( ) ;
329-
330- while let Some ( ( file, res, mut buf) ) = stream. next ( ) . await {
347+ let mut connection_refused = false ;
348+ while let Some ( ( db_name, file, res, mut buf) ) = stream. next ( ) . await {
349+ remaining_files. remove ( & file) ;
331350 let test_case_name = file. replace ( [ '/' , ' ' , '.' , '-' ] , "_" ) ;
351+ let mut failed = false ;
332352 let case = match res {
333353 Ok ( duration) => {
334354 let mut case = TestCase :: new ( test_case_name, TestCaseStatus :: success ( ) ) ;
@@ -338,9 +358,15 @@ async fn run_parallel(
338358 case
339359 }
340360 Err ( e) => {
341- writeln ! ( buf, "{}\n \n {:?}" , style( "[FAILED]" ) . red( ) . bold( ) , e) ?;
361+ failed = true ;
362+ let err = format ! ( "{:?}" , e) ;
363+ if err. contains ( "Connection refused" ) {
364+ connection_refused = true ;
365+ }
366+ writeln ! ( buf, "{}\n \n {}" , style( "[FAILED]" ) . red( ) . bold( ) , err) ?;
342367 writeln ! ( buf) ?;
343368 failed_case. push ( file. clone ( ) ) ;
369+ failed_db. insert ( db_name. clone ( ) ) ;
344370 let mut status = TestCaseStatus :: non_success ( NonSuccessKind :: Failure ) ;
345371 status. set_type ( "test failure" ) ;
346372 let mut case = TestCase :: new ( test_case_name, status) ;
@@ -354,18 +380,60 @@ async fn run_parallel(
354380 } ;
355381 test_suite. add_test_case ( case) ;
356382 tokio:: task:: block_in_place ( || stdout ( ) . write_all ( & buf) ) ?;
383+ if connection_refused {
384+ eprintln ! ( "Connection refused. The server may be down. Exiting..." ) ;
385+ break ;
386+ }
387+ if fail_fast && failed {
388+ println ! ( "early exit after failure..." ) ;
389+ break ;
390+ }
391+ }
392+
393+ for file in remaining_files {
394+ println ! ( "{file} is not finished, skipping" ) ;
395+ let test_case_name = file. replace ( [ '/' , ' ' , '.' , '-' ] , "_" ) ;
396+ let mut case = TestCase :: new ( test_case_name, TestCaseStatus :: skipped ( ) ) ;
397+ case. set_time ( Duration :: from_millis ( 0 ) ) ;
398+ case. set_timestamp ( Local :: now ( ) ) ;
399+ case. set_classname ( junit. as_deref ( ) . unwrap_or_default ( ) ) ;
400+ test_suite. add_test_case ( case) ;
357401 }
358402
359403 eprintln ! (
360404 "\n All test cases finished in {} ms" ,
361405 start. elapsed( ) . as_millis( )
362406 ) ;
363407
364- for db_name in db_names {
365- let query = format ! ( "DROP DATABASE {db_name};" ) ;
366- eprintln ! ( "+ {query}" ) ;
367- if let Err ( err) = db. run ( & query) . await {
368- eprintln ! ( " ignore error: {err}" ) ;
408+ // If `fail_fast`, there could be some ongoing cases (then active connections)
409+ // in the stream. Abort them before dropping temporary databases.
410+ drop ( stream) ;
411+
412+ if connection_refused {
413+ eprintln ! ( "Skip dropping databases due to connection refused: {db_names:?}" ) ;
414+ } else {
415+ for db_name in db_names {
416+ if keep_db_on_failure && failed_db. contains ( & db_name) {
417+ eprintln ! (
418+ "+ {}" ,
419+ style( format!(
420+ "DATABASE {db_name} contains failed cases, kept for debugging"
421+ ) )
422+ . red( )
423+ . bold( )
424+ ) ;
425+ continue ;
426+ }
427+ let query = format ! ( "DROP DATABASE {db_name};" ) ;
428+ eprintln ! ( "+ {query}" ) ;
429+ if let Err ( err) = db. run ( & query) . await {
430+ let err = err. to_string ( ) ;
431+ if err. contains ( "Connection refused" ) {
432+ eprintln ! ( " Connection refused. The server may be down. Exiting..." ) ;
433+ break ;
434+ }
435+ eprintln ! ( " ignore DROP DATABASE error: {err}" ) ;
436+ }
369437 }
370438 }
371439
@@ -384,17 +452,21 @@ async fn run_serial(
384452 config : DBConfig ,
385453 labels : & [ String ] ,
386454 junit : Option < String > ,
455+ fail_fast : bool ,
387456) -> Result < ( ) > {
388457 let mut failed_case = vec ! [ ] ;
389-
390- for file in files {
458+ let mut skipped_case = vec ! [ ] ;
459+ let mut files = files. into_iter ( ) ;
460+ let mut connection_refused = false ;
461+ for file in & mut files {
391462 let mut runner = Runner :: new ( || engines:: connect ( engine, & config) ) ;
392463 for label in labels {
393464 runner. add_label ( label) ;
394465 }
395466
396467 let filename = file. to_string_lossy ( ) . to_string ( ) ;
397468 let test_case_name = filename. replace ( [ '/' , ' ' , '.' , '-' ] , "_" ) ;
469+ let mut failed = false ;
398470 let case = match run_test_file ( & mut std:: io:: stdout ( ) , runner, & file) . await {
399471 Ok ( duration) => {
400472 let mut case = TestCase :: new ( test_case_name, TestCaseStatus :: success ( ) ) ;
@@ -404,7 +476,12 @@ async fn run_serial(
404476 case
405477 }
406478 Err ( e) => {
407- println ! ( "{}\n \n {:?}" , style( "[FAILED]" ) . red( ) . bold( ) , e) ;
479+ failed = true ;
480+ let err = format ! ( "{:?}" , e) ;
481+ if err. contains ( "Connection refused" ) {
482+ connection_refused = true ;
483+ }
484+ println ! ( "{}\n \n {}" , style( "[FAILED]" ) . red( ) . bold( ) , err) ;
408485 println ! ( ) ;
409486 failed_case. push ( filename. clone ( ) ) ;
410487 let mut status = TestCaseStatus :: non_success ( NonSuccessKind :: Failure ) ;
@@ -419,6 +496,27 @@ async fn run_serial(
419496 }
420497 } ;
421498 test_suite. add_test_case ( case) ;
499+ if connection_refused {
500+ eprintln ! ( "Connection refused. The server may be down. Exiting..." ) ;
501+ break ;
502+ }
503+ if fail_fast && failed {
504+ println ! ( "early exit after failure..." ) ;
505+ break ;
506+ }
507+ }
508+ for file in files {
509+ let filename = file. to_string_lossy ( ) . to_string ( ) ;
510+ let test_case_name = filename. replace ( [ '/' , ' ' , '.' , '-' ] , "_" ) ;
511+ let mut case = TestCase :: new ( test_case_name, TestCaseStatus :: skipped ( ) ) ;
512+ case. set_time ( Duration :: from_millis ( 0 ) ) ;
513+ case. set_timestamp ( Local :: now ( ) ) ;
514+ case. set_classname ( junit. as_deref ( ) . unwrap_or_default ( ) ) ;
515+ test_suite. add_test_case ( case) ;
516+ skipped_case. push ( filename. clone ( ) ) ;
517+ }
518+ if !skipped_case. is_empty ( ) {
519+ println ! ( "some test case skipped:\n {:#?}" , skipped_case) ;
422520 }
423521
424522 if !failed_case. is_empty ( ) {
@@ -750,7 +848,8 @@ async fn update_record<M: MakeConnection>(
750848 & record_output,
751849 "\t " ,
752850 default_validator,
753- strict_column_validator,
851+ default_normalizer,
852+ default_column_validator,
754853 ) {
755854 Some ( new_record) => {
756855 writeln ! ( outfile, "{new_record}" ) ?;
0 commit comments