@@ -23,19 +23,22 @@ use crate::{ColumnType, Connections, MakeConnection};
23
23
/// Type-erased error type.
24
24
type AnyError = Arc < dyn std:: error:: Error + Send + Sync > ;
25
25
26
+ /// Output of a record.
26
27
#[ derive( Debug , Clone ) ]
27
28
#[ non_exhaustive]
28
29
pub enum RecordOutput < T : ColumnType > {
30
+ /// No output. Occurs when the record is skipped or not a `query`, `statement`, or `system`
31
+ /// command.
29
32
Nothing ,
33
+ /// The output of a `query`.
30
34
Query {
31
35
types : Vec < T > ,
32
36
rows : Vec < Vec < String > > ,
33
37
error : Option < AnyError > ,
34
38
} ,
35
- Statement {
36
- count : u64 ,
37
- error : Option < AnyError > ,
38
- } ,
39
+ /// The output of a `statement`.
40
+ Statement { count : u64 , error : Option < AnyError > } ,
41
+ /// The output of a `system` command.
39
42
#[ non_exhaustive]
40
43
System {
41
44
stdout : Option < String > ,
@@ -833,10 +836,13 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
833
836
}
834
837
835
838
/// Run a single record.
836
- pub async fn run_async ( & mut self , record : Record < D :: ColumnType > ) -> Result < ( ) , TestError > {
839
+ pub async fn run_async (
840
+ & mut self ,
841
+ record : Record < D :: ColumnType > ,
842
+ ) -> Result < RecordOutput < D :: ColumnType > , TestError > {
837
843
let result = self . apply_record ( record. clone ( ) ) . await ;
838
844
839
- match ( record, result) {
845
+ match ( record, & result) {
840
846
( _, RecordOutput :: Nothing ) => { }
841
847
// Tolerate the mismatched return type...
842
848
(
@@ -894,7 +900,7 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
894
900
. at ( loc) )
895
901
}
896
902
( None , StatementExpect :: Count ( expected_count) ) => {
897
- if expected_count != count {
903
+ if expected_count != * count {
898
904
return Err ( TestErrorKind :: StatementResultMismatch {
899
905
sql,
900
906
expected : expected_count,
@@ -908,7 +914,7 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
908
914
if !expected_error. is_match ( & e. to_string ( ) ) {
909
915
return Err ( TestErrorKind :: ErrorMismatch {
910
916
sql,
911
- err : Arc :: new ( e) ,
917
+ err : Arc :: clone ( e) ,
912
918
expected_err : expected_error. to_string ( ) ,
913
919
kind : RecordKind :: Statement ,
914
920
}
@@ -918,7 +924,7 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
918
924
( Some ( e) , StatementExpect :: Count ( _) | StatementExpect :: Ok ) => {
919
925
return Err ( TestErrorKind :: Fail {
920
926
sql,
921
- err : Arc :: new ( e) ,
927
+ err : Arc :: clone ( e) ,
922
928
kind : RecordKind :: Statement ,
923
929
}
924
930
. at ( loc) ) ;
@@ -946,7 +952,7 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
946
952
if !expected_error. is_match ( & e. to_string ( ) ) {
947
953
return Err ( TestErrorKind :: ErrorMismatch {
948
954
sql,
949
- err : Arc :: new ( e) ,
955
+ err : Arc :: clone ( e) ,
950
956
expected_err : expected_error. to_string ( ) ,
951
957
kind : RecordKind :: Query ,
952
958
}
@@ -956,7 +962,7 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
956
962
( Some ( e) , QueryExpect :: Results { .. } ) => {
957
963
return Err ( TestErrorKind :: Fail {
958
964
sql,
959
- err : Arc :: new ( e) ,
965
+ err : Arc :: clone ( e) ,
960
966
kind : RecordKind :: Query ,
961
967
}
962
968
. at ( loc) ) ;
@@ -969,7 +975,7 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
969
975
..
970
976
} ,
971
977
) => {
972
- if !( self . column_type_validator ) ( & types, & expected_types) {
978
+ if !( self . column_type_validator ) ( types, & expected_types) {
973
979
return Err ( TestErrorKind :: QueryResultColumnsMismatch {
974
980
sql,
975
981
expected : expected_types. iter ( ) . map ( |c| c. to_char ( ) ) . join ( "" ) ,
@@ -978,11 +984,9 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
978
984
. at ( loc) ) ;
979
985
}
980
986
981
- if !( self . validator ) ( & rows, & expected_results) {
982
- let output_rows = rows
983
- . into_iter ( )
984
- . map ( |strs| strs. iter ( ) . join ( " " ) )
985
- . collect_vec ( ) ;
987
+ if !( self . validator ) ( rows, & expected_results) {
988
+ let output_rows =
989
+ rows. iter ( ) . map ( |strs| strs. iter ( ) . join ( " " ) ) . collect_vec ( ) ;
986
990
return Err ( TestErrorKind :: QueryResultMismatch {
987
991
sql,
988
992
expected : expected_results. join ( "\n " ) ,
@@ -1006,12 +1010,16 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
1006
1010
} ,
1007
1011
) => {
1008
1012
if let Some ( err) = error {
1009
- return Err ( TestErrorKind :: SystemFail { command, err } . at ( loc) ) ;
1013
+ return Err ( TestErrorKind :: SystemFail {
1014
+ command,
1015
+ err : Arc :: clone ( err) ,
1016
+ }
1017
+ . at ( loc) ) ;
1010
1018
}
1011
1019
match ( expected_stdout, actual_stdout) {
1012
1020
( None , _) => { }
1013
1021
( Some ( expected_stdout) , actual_stdout) => {
1014
- let actual_stdout = actual_stdout. unwrap_or_default ( ) ;
1022
+ let actual_stdout = actual_stdout. clone ( ) . unwrap_or_default ( ) ;
1015
1023
// TODO: support newlines contained in expected_stdout
1016
1024
if expected_stdout != actual_stdout. trim ( ) {
1017
1025
return Err ( TestErrorKind :: SystemStdoutMismatch {
@@ -1027,17 +1035,24 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
1027
1035
_ => unreachable ! ( ) ,
1028
1036
}
1029
1037
1030
- Ok ( ( ) )
1038
+ Ok ( result )
1031
1039
}
1032
1040
1033
1041
/// Run a single record.
1034
- pub fn run ( & mut self , record : Record < D :: ColumnType > ) -> Result < ( ) , TestError > {
1042
+ ///
1043
+ /// Returns the output of the record if successful.
1044
+ pub fn run (
1045
+ & mut self ,
1046
+ record : Record < D :: ColumnType > ,
1047
+ ) -> Result < RecordOutput < D :: ColumnType > , TestError > {
1035
1048
futures:: executor:: block_on ( self . run_async ( record) )
1036
1049
}
1037
1050
1038
1051
/// Run multiple records.
1039
1052
///
1040
1053
/// The runner will stop early once a halt record is seen.
1054
+ ///
1055
+ /// To acquire the result of each record, manually call `run_async` for each record instead.
1041
1056
pub async fn run_multi_async (
1042
1057
& mut self ,
1043
1058
records : impl IntoIterator < Item = Record < D :: ColumnType > > ,
@@ -1054,6 +1069,8 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
1054
1069
/// Run multiple records.
1055
1070
///
1056
1071
/// The runner will stop early once a halt record is seen.
1072
+ ///
1073
+ /// To acquire the result of each record, manually call `run` for each record instead.
1057
1074
pub fn run_multi (
1058
1075
& mut self ,
1059
1076
records : impl IntoIterator < Item = Record < D :: ColumnType > > ,
0 commit comments