Skip to content

Commit 2441dff

Browse files
authored
refactor: return record output in run (#219)
1 parent 17d81db commit 2441dff

File tree

6 files changed

+47
-28
lines changed

6 files changed

+47
-28
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## Unreleased
99

10+
* runner: `RecordOutput` is now returned by `Runner::run` (or `Runner::run_async`). This allows users to access the output of each record, or check whether the record is skipped.
11+
1012
## [0.20.6] - 2024-06-21
1113

1214
* runner: add logs for `system` command (with target `sqllogictest::system_command`) for ease of debugging.

Cargo.lock

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ resolver = "2"
33
members = ["sqllogictest", "sqllogictest-bin", "sqllogictest-engines", "tests"]
44

55
[workspace.package]
6-
version = "0.20.6"
6+
version = "0.21.0"
77
edition = "2021"
88
homepage = "https://github.com/risinglightdb/sqllogictest-rs"
99
keywords = ["sql", "database", "parser", "cli"]

sqllogictest-bin/Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ glob = "0.3"
2323
itertools = "0.13"
2424
quick-junit = { version = "0.4" }
2525
rand = "0.8"
26-
sqllogictest = { path = "../sqllogictest", version = "0.20" }
27-
sqllogictest-engines = { path = "../sqllogictest-engines", version = "0.20" }
26+
sqllogictest = { path = "../sqllogictest", version = "0.21" }
27+
sqllogictest-engines = { path = "../sqllogictest-engines", version = "0.21" }
2828
tokio = { version = "1", features = [
2929
"rt",
3030
"rt-multi-thread",

sqllogictest-engines/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ postgres-types = { version = "0.2.5", features = ["derive", "with-chrono-0_4"] }
1919
rust_decimal = { version = "1.30.0", features = ["tokio-pg"] }
2020
serde = { version = "1", features = ["derive"] }
2121
serde_json = "1"
22-
sqllogictest = { path = "../sqllogictest", version = "0.20" }
22+
sqllogictest = { path = "../sqllogictest", version = "0.21" }
2323
thiserror = "1"
2424
tokio = { version = "1", features = [
2525
"rt",

sqllogictest/src/runner.rs

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,22 @@ use crate::{ColumnType, Connections, MakeConnection};
2323
/// Type-erased error type.
2424
type AnyError = Arc<dyn std::error::Error + Send + Sync>;
2525

26+
/// Output of a record.
2627
#[derive(Debug, Clone)]
2728
#[non_exhaustive]
2829
pub enum RecordOutput<T: ColumnType> {
30+
/// No output. Occurs when the record is skipped or not a `query`, `statement`, or `system`
31+
/// command.
2932
Nothing,
33+
/// The output of a `query`.
3034
Query {
3135
types: Vec<T>,
3236
rows: Vec<Vec<String>>,
3337
error: Option<AnyError>,
3438
},
35-
Statement {
36-
count: u64,
37-
error: Option<AnyError>,
38-
},
39+
/// The output of a `statement`.
40+
Statement { count: u64, error: Option<AnyError> },
41+
/// The output of a `system` command.
3942
#[non_exhaustive]
4043
System {
4144
stdout: Option<String>,
@@ -833,10 +836,13 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
833836
}
834837

835838
/// Run a single record.
836-
pub async fn run_async(&mut self, record: Record<D::ColumnType>) -> Result<(), TestError> {
839+
pub async fn run_async(
840+
&mut self,
841+
record: Record<D::ColumnType>,
842+
) -> Result<RecordOutput<D::ColumnType>, TestError> {
837843
let result = self.apply_record(record.clone()).await;
838844

839-
match (record, result) {
845+
match (record, &result) {
840846
(_, RecordOutput::Nothing) => {}
841847
// Tolerate the mismatched return type...
842848
(
@@ -894,7 +900,7 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
894900
.at(loc))
895901
}
896902
(None, StatementExpect::Count(expected_count)) => {
897-
if expected_count != count {
903+
if expected_count != *count {
898904
return Err(TestErrorKind::StatementResultMismatch {
899905
sql,
900906
expected: expected_count,
@@ -908,7 +914,7 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
908914
if !expected_error.is_match(&e.to_string()) {
909915
return Err(TestErrorKind::ErrorMismatch {
910916
sql,
911-
err: Arc::new(e),
917+
err: Arc::clone(e),
912918
expected_err: expected_error.to_string(),
913919
kind: RecordKind::Statement,
914920
}
@@ -918,7 +924,7 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
918924
(Some(e), StatementExpect::Count(_) | StatementExpect::Ok) => {
919925
return Err(TestErrorKind::Fail {
920926
sql,
921-
err: Arc::new(e),
927+
err: Arc::clone(e),
922928
kind: RecordKind::Statement,
923929
}
924930
.at(loc));
@@ -946,7 +952,7 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
946952
if !expected_error.is_match(&e.to_string()) {
947953
return Err(TestErrorKind::ErrorMismatch {
948954
sql,
949-
err: Arc::new(e),
955+
err: Arc::clone(e),
950956
expected_err: expected_error.to_string(),
951957
kind: RecordKind::Query,
952958
}
@@ -956,7 +962,7 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
956962
(Some(e), QueryExpect::Results { .. }) => {
957963
return Err(TestErrorKind::Fail {
958964
sql,
959-
err: Arc::new(e),
965+
err: Arc::clone(e),
960966
kind: RecordKind::Query,
961967
}
962968
.at(loc));
@@ -969,7 +975,7 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
969975
..
970976
},
971977
) => {
972-
if !(self.column_type_validator)(&types, &expected_types) {
978+
if !(self.column_type_validator)(types, &expected_types) {
973979
return Err(TestErrorKind::QueryResultColumnsMismatch {
974980
sql,
975981
expected: expected_types.iter().map(|c| c.to_char()).join(""),
@@ -978,11 +984,9 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
978984
.at(loc));
979985
}
980986

981-
if !(self.validator)(&rows, &expected_results) {
982-
let output_rows = rows
983-
.into_iter()
984-
.map(|strs| strs.iter().join(" "))
985-
.collect_vec();
987+
if !(self.validator)(rows, &expected_results) {
988+
let output_rows =
989+
rows.iter().map(|strs| strs.iter().join(" ")).collect_vec();
986990
return Err(TestErrorKind::QueryResultMismatch {
987991
sql,
988992
expected: expected_results.join("\n"),
@@ -1006,12 +1010,16 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
10061010
},
10071011
) => {
10081012
if let Some(err) = error {
1009-
return Err(TestErrorKind::SystemFail { command, err }.at(loc));
1013+
return Err(TestErrorKind::SystemFail {
1014+
command,
1015+
err: Arc::clone(err),
1016+
}
1017+
.at(loc));
10101018
}
10111019
match (expected_stdout, actual_stdout) {
10121020
(None, _) => {}
10131021
(Some(expected_stdout), actual_stdout) => {
1014-
let actual_stdout = actual_stdout.unwrap_or_default();
1022+
let actual_stdout = actual_stdout.clone().unwrap_or_default();
10151023
// TODO: support newlines contained in expected_stdout
10161024
if expected_stdout != actual_stdout.trim() {
10171025
return Err(TestErrorKind::SystemStdoutMismatch {
@@ -1027,17 +1035,24 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
10271035
_ => unreachable!(),
10281036
}
10291037

1030-
Ok(())
1038+
Ok(result)
10311039
}
10321040

10331041
/// Run a single record.
1034-
pub fn run(&mut self, record: Record<D::ColumnType>) -> Result<(), TestError> {
1042+
///
1043+
/// Returns the output of the record if successful.
1044+
pub fn run(
1045+
&mut self,
1046+
record: Record<D::ColumnType>,
1047+
) -> Result<RecordOutput<D::ColumnType>, TestError> {
10351048
futures::executor::block_on(self.run_async(record))
10361049
}
10371050

10381051
/// Run multiple records.
10391052
///
10401053
/// The runner will stop early once a halt record is seen.
1054+
///
1055+
/// To acquire the result of each record, manually call `run_async` for each record instead.
10411056
pub async fn run_multi_async(
10421057
&mut self,
10431058
records: impl IntoIterator<Item = Record<D::ColumnType>>,
@@ -1054,6 +1069,8 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
10541069
/// Run multiple records.
10551070
///
10561071
/// The runner will stop early once a halt record is seen.
1072+
///
1073+
/// To acquire the result of each record, manually call `run` for each record instead.
10571074
pub fn run_multi(
10581075
&mut self,
10591076
records: impl IntoIterator<Item = Record<D::ColumnType>>,

0 commit comments

Comments
 (0)