1
1
mod engines;
2
2
3
- use std:: collections:: { BTreeMap , BTreeSet } ;
3
+ use std:: collections:: { BTreeMap , BTreeSet , HashSet } ;
4
4
use std:: io:: { stdout, Read , Seek , SeekFrom , Write } ;
5
5
use std:: path:: { Path , PathBuf } ;
6
6
use std:: time:: { Duration , Instant } ;
@@ -17,9 +17,10 @@ use quick_junit::{NonSuccessKind, Report, TestCase, TestCaseStatus, TestSuite};
17
17
use rand:: distributions:: DistString ;
18
18
use rand:: seq:: SliceRandom ;
19
19
use sqllogictest:: {
20
- default_validator , strict_column_validator , update_record_with_output , AsyncDB , Injected ,
21
- MakeConnection , Record , Runner ,
20
+ default_column_validator , default_normalizer , default_validator , update_record_with_output ,
21
+ AsyncDB , Injected , MakeConnection , Record , Runner ,
22
22
} ;
23
+ use tokio_util:: task:: AbortOnDropHandle ;
23
24
24
25
#[ derive( Default , Copy , Clone , Debug , PartialEq , Eq , ValueEnum ) ]
25
26
#[ must_use]
@@ -62,6 +63,13 @@ struct Opt {
62
63
/// database will be created for each test file.
63
64
#[ clap( long, short) ]
64
65
jobs : Option < usize > ,
66
+ /// When using `-j`, whether to keep the temporary database when a test case fails.
67
+ #[ clap( long, default_value = "false" , env = "SLT_KEEP_DB_ON_FAILURE" ) ]
68
+ keep_db_on_failure : bool ,
69
+
70
+ /// Whether to exit immediately when a test case fails.
71
+ #[ clap( long, default_value = "false" , env = "SLT_FAIL_FAST" ) ]
72
+ fail_fast : bool ,
65
73
66
74
/// Report to junit XML.
67
75
#[ clap( long) ]
@@ -146,6 +154,8 @@ pub async fn main() -> Result<()> {
146
154
external_engine_command_template,
147
155
color,
148
156
jobs,
157
+ keep_db_on_failure,
158
+ fail_fast,
149
159
junit,
150
160
host,
151
161
port,
@@ -228,12 +238,14 @@ pub async fn main() -> Result<()> {
228
238
let result = if let Some ( jobs) = jobs {
229
239
run_parallel (
230
240
jobs,
241
+ keep_db_on_failure,
231
242
& mut test_suite,
232
243
files,
233
244
& engine,
234
245
config,
235
246
& labels,
236
247
junit. clone ( ) ,
248
+ fail_fast,
237
249
)
238
250
. await
239
251
} else {
@@ -244,6 +256,7 @@ pub async fn main() -> Result<()> {
244
256
config,
245
257
& labels,
246
258
junit. clone ( ) ,
259
+ fail_fast,
247
260
)
248
261
. await
249
262
} ;
@@ -257,14 +270,17 @@ pub async fn main() -> Result<()> {
257
270
result
258
271
}
259
272
273
+ #[ allow( clippy:: too_many_arguments) ]
260
274
async fn run_parallel (
261
275
jobs : usize ,
276
+ keep_db_on_failure : bool ,
262
277
test_suite : & mut TestSuite ,
263
278
files : Vec < PathBuf > ,
264
279
engine : & EngineConfig ,
265
280
config : DBConfig ,
266
281
labels : & [ String ] ,
267
282
junit : Option < String > ,
283
+ fail_fast : bool ,
268
284
) -> Result < ( ) > {
269
285
let mut create_databases = BTreeMap :: new ( ) ;
270
286
let mut filenames = BTreeSet :: new ( ) ;
@@ -299,36 +315,40 @@ async fn run_parallel(
299
315
}
300
316
}
301
317
302
- let mut stream = futures:: stream:: iter ( create_databases. into_iter ( ) )
318
+ let mut stream = futures:: stream:: iter ( create_databases)
303
319
. map ( |( db_name, filename) | {
304
320
let mut config = config. clone ( ) ;
305
- config. db = db_name;
321
+ config. db . clone_from ( & db_name) ;
306
322
let file = filename. to_string_lossy ( ) . to_string ( ) ;
307
323
let engine = engine. clone ( ) ;
308
324
let labels = labels. to_vec ( ) ;
309
325
async move {
310
- let ( buf, res) = tokio:: spawn ( async move {
326
+ let ( buf, res) = AbortOnDropHandle :: new ( tokio:: spawn ( async move {
311
327
let mut buf = vec ! [ ] ;
312
328
let res =
313
329
connect_and_run_test_file ( & mut buf, filename, & engine, config, & labels)
314
330
. await ;
315
331
( buf, res)
316
- } )
332
+ } ) )
317
333
. await
318
334
. unwrap ( ) ;
319
- ( file, res, buf)
335
+ ( db_name , file, res, buf)
320
336
}
321
337
} )
322
338
. buffer_unordered ( jobs) ;
323
339
324
340
eprintln ! ( "{}" , style( "[TEST IN PROGRESS]" ) . blue( ) . bold( ) ) ;
325
341
326
342
let mut failed_case = vec ! [ ] ;
343
+ let mut failed_db: HashSet < String > = HashSet :: new ( ) ;
344
+ let mut remaining_files: HashSet < String > = HashSet :: from_iter ( filenames. clone ( ) ) ;
327
345
328
346
let start = Instant :: now ( ) ;
329
-
330
- while let Some ( ( file, res, mut buf) ) = stream. next ( ) . await {
347
+ let mut connection_refused = false ;
348
+ while let Some ( ( db_name, file, res, mut buf) ) = stream. next ( ) . await {
349
+ remaining_files. remove ( & file) ;
331
350
let test_case_name = file. replace ( [ '/' , ' ' , '.' , '-' ] , "_" ) ;
351
+ let mut failed = false ;
332
352
let case = match res {
333
353
Ok ( duration) => {
334
354
let mut case = TestCase :: new ( test_case_name, TestCaseStatus :: success ( ) ) ;
@@ -338,9 +358,15 @@ async fn run_parallel(
338
358
case
339
359
}
340
360
Err ( e) => {
341
- writeln ! ( buf, "{}\n \n {:?}" , style( "[FAILED]" ) . red( ) . bold( ) , e) ?;
361
+ failed = true ;
362
+ let err = format ! ( "{:?}" , e) ;
363
+ if err. contains ( "Connection refused" ) {
364
+ connection_refused = true ;
365
+ }
366
+ writeln ! ( buf, "{}\n \n {}" , style( "[FAILED]" ) . red( ) . bold( ) , err) ?;
342
367
writeln ! ( buf) ?;
343
368
failed_case. push ( file. clone ( ) ) ;
369
+ failed_db. insert ( db_name. clone ( ) ) ;
344
370
let mut status = TestCaseStatus :: non_success ( NonSuccessKind :: Failure ) ;
345
371
status. set_type ( "test failure" ) ;
346
372
let mut case = TestCase :: new ( test_case_name, status) ;
@@ -354,18 +380,60 @@ async fn run_parallel(
354
380
} ;
355
381
test_suite. add_test_case ( case) ;
356
382
tokio:: task:: block_in_place ( || stdout ( ) . write_all ( & buf) ) ?;
383
+ if connection_refused {
384
+ eprintln ! ( "Connection refused. The server may be down. Exiting..." ) ;
385
+ break ;
386
+ }
387
+ if fail_fast && failed {
388
+ println ! ( "early exit after failure..." ) ;
389
+ break ;
390
+ }
391
+ }
392
+
393
+ for file in remaining_files {
394
+ println ! ( "{file} is not finished, skipping" ) ;
395
+ let test_case_name = file. replace ( [ '/' , ' ' , '.' , '-' ] , "_" ) ;
396
+ let mut case = TestCase :: new ( test_case_name, TestCaseStatus :: skipped ( ) ) ;
397
+ case. set_time ( Duration :: from_millis ( 0 ) ) ;
398
+ case. set_timestamp ( Local :: now ( ) ) ;
399
+ case. set_classname ( junit. as_deref ( ) . unwrap_or_default ( ) ) ;
400
+ test_suite. add_test_case ( case) ;
357
401
}
358
402
359
403
eprintln ! (
360
404
"\n All test cases finished in {} ms" ,
361
405
start. elapsed( ) . as_millis( )
362
406
) ;
363
407
364
- for db_name in db_names {
365
- let query = format ! ( "DROP DATABASE {db_name};" ) ;
366
- eprintln ! ( "+ {query}" ) ;
367
- if let Err ( err) = db. run ( & query) . await {
368
- eprintln ! ( " ignore error: {err}" ) ;
408
+ // If `fail_fast`, there could be some ongoing cases (then active connections)
409
+ // in the stream. Abort them before dropping temporary databases.
410
+ drop ( stream) ;
411
+
412
+ if connection_refused {
413
+ eprintln ! ( "Skip dropping databases due to connection refused: {db_names:?}" ) ;
414
+ } else {
415
+ for db_name in db_names {
416
+ if keep_db_on_failure && failed_db. contains ( & db_name) {
417
+ eprintln ! (
418
+ "+ {}" ,
419
+ style( format!(
420
+ "DATABASE {db_name} contains failed cases, kept for debugging"
421
+ ) )
422
+ . red( )
423
+ . bold( )
424
+ ) ;
425
+ continue ;
426
+ }
427
+ let query = format ! ( "DROP DATABASE {db_name};" ) ;
428
+ eprintln ! ( "+ {query}" ) ;
429
+ if let Err ( err) = db. run ( & query) . await {
430
+ let err = err. to_string ( ) ;
431
+ if err. contains ( "Connection refused" ) {
432
+ eprintln ! ( " Connection refused. The server may be down. Exiting..." ) ;
433
+ break ;
434
+ }
435
+ eprintln ! ( " ignore DROP DATABASE error: {err}" ) ;
436
+ }
369
437
}
370
438
}
371
439
@@ -384,17 +452,21 @@ async fn run_serial(
384
452
config : DBConfig ,
385
453
labels : & [ String ] ,
386
454
junit : Option < String > ,
455
+ fail_fast : bool ,
387
456
) -> Result < ( ) > {
388
457
let mut failed_case = vec ! [ ] ;
389
-
390
- for file in files {
458
+ let mut skipped_case = vec ! [ ] ;
459
+ let mut files = files. into_iter ( ) ;
460
+ let mut connection_refused = false ;
461
+ for file in & mut files {
391
462
let mut runner = Runner :: new ( || engines:: connect ( engine, & config) ) ;
392
463
for label in labels {
393
464
runner. add_label ( label) ;
394
465
}
395
466
396
467
let filename = file. to_string_lossy ( ) . to_string ( ) ;
397
468
let test_case_name = filename. replace ( [ '/' , ' ' , '.' , '-' ] , "_" ) ;
469
+ let mut failed = false ;
398
470
let case = match run_test_file ( & mut std:: io:: stdout ( ) , runner, & file) . await {
399
471
Ok ( duration) => {
400
472
let mut case = TestCase :: new ( test_case_name, TestCaseStatus :: success ( ) ) ;
@@ -404,7 +476,12 @@ async fn run_serial(
404
476
case
405
477
}
406
478
Err ( e) => {
407
- println ! ( "{}\n \n {:?}" , style( "[FAILED]" ) . red( ) . bold( ) , e) ;
479
+ failed = true ;
480
+ let err = format ! ( "{:?}" , e) ;
481
+ if err. contains ( "Connection refused" ) {
482
+ connection_refused = true ;
483
+ }
484
+ println ! ( "{}\n \n {}" , style( "[FAILED]" ) . red( ) . bold( ) , err) ;
408
485
println ! ( ) ;
409
486
failed_case. push ( filename. clone ( ) ) ;
410
487
let mut status = TestCaseStatus :: non_success ( NonSuccessKind :: Failure ) ;
@@ -419,6 +496,27 @@ async fn run_serial(
419
496
}
420
497
} ;
421
498
test_suite. add_test_case ( case) ;
499
+ if connection_refused {
500
+ eprintln ! ( "Connection refused. The server may be down. Exiting..." ) ;
501
+ break ;
502
+ }
503
+ if fail_fast && failed {
504
+ println ! ( "early exit after failure..." ) ;
505
+ break ;
506
+ }
507
+ }
508
+ for file in files {
509
+ let filename = file. to_string_lossy ( ) . to_string ( ) ;
510
+ let test_case_name = filename. replace ( [ '/' , ' ' , '.' , '-' ] , "_" ) ;
511
+ let mut case = TestCase :: new ( test_case_name, TestCaseStatus :: skipped ( ) ) ;
512
+ case. set_time ( Duration :: from_millis ( 0 ) ) ;
513
+ case. set_timestamp ( Local :: now ( ) ) ;
514
+ case. set_classname ( junit. as_deref ( ) . unwrap_or_default ( ) ) ;
515
+ test_suite. add_test_case ( case) ;
516
+ skipped_case. push ( filename. clone ( ) ) ;
517
+ }
518
+ if !skipped_case. is_empty ( ) {
519
+ println ! ( "some test case skipped:\n {:#?}" , skipped_case) ;
422
520
}
423
521
424
522
if !failed_case. is_empty ( ) {
@@ -750,7 +848,8 @@ async fn update_record<M: MakeConnection>(
750
848
& record_output,
751
849
"\t " ,
752
850
default_validator,
753
- strict_column_validator,
851
+ default_normalizer,
852
+ default_column_validator,
754
853
) {
755
854
Some ( new_record) => {
756
855
writeln ! ( outfile, "{new_record}" ) ?;
0 commit comments