Skip to content

Commit 2755a56

Browse files
authored
Merge pull request #2 from featuremesh/fix/error-comparing-columns
skip checking col names if not provided
2 parents 06c3bb9 + 0ecdf9d commit 2755a56

File tree

2 files changed

+77
-47
lines changed

2 files changed

+77
-47
lines changed

sqllogictest/src/parser.rs

Lines changed: 41 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,8 @@ pub enum ParseErrorKind {
585585
UnexpectedEOF,
586586
#[error("invalid sort mode: {0:?}")]
587587
InvalidSortMode(String),
588+
#[error("invalid column: {0:?}")]
589+
InvalidColumn(String),
588590
#[error("invalid line: {0:?}")]
589591
InvalidLine(String),
590592
#[error("invalid type character: {0:?} in type string")]
@@ -747,8 +749,8 @@ fn parse_inner<T: ColumnType>(loc: &Location, script: &str) -> Result<Vec<Record
747749
.map_err(|e| e.at(loc.clone()))?;
748750
QueryExpect::Error(error)
749751
}
750-
[type_str, res @ ..] => {
751-
let cols = parse_cols(&loc, type_str)?;
752+
[col_str, res @ ..] => {
753+
let cols = parse_cols(col_str, &loc)?;
752754
let sort_mode = res
753755
.first()
754756
.map(|&s| SortMode::try_from_str(s))
@@ -840,43 +842,36 @@ fn parse_inner<T: ColumnType>(loc: &Location, script: &str) -> Result<Vec<Record
840842
Ok(records)
841843
}
842844

843-
fn parse_cols<T: ColumnType>(
844-
loc: &Location,
845-
type_str: &&str,
846-
) -> Result<Vec<Column<T>>, ParseError> {
847-
// Check if contains ':' or ',' to determine format
848-
if type_str.contains(':') {
849-
// Parse "c1:I,name:I" format
850-
type_str
845+
fn parse_cols<T: ColumnType>(col_str: &&str, loc: &Location) -> Result<Vec<Column<T>>, ParseError> {
846+
fn parse_type_char<T: ColumnType>(
847+
part: &str,
848+
name: String,
849+
loc: &Location,
850+
) -> Result<Column<T>, ParseError> {
851+
let type_char = part
852+
.trim()
853+
.chars()
854+
.next()
855+
.ok_or_else(|| ParseErrorKind::InvalidSortMode(part.into()).at(loc.clone()))?;
856+
857+
T::from_char(type_char)
858+
.map(|t| Column { name, r#type: t })
859+
.ok_or_else(|| ParseErrorKind::InvalidType(type_char).at(loc.clone()))
860+
}
861+
if col_str.contains(':') {
862+
col_str
851863
.split(',')
852-
.map(|part| {
853-
let (name, type_char) = part
854-
.split_once(':')
855-
.ok_or_else(|| ParseErrorKind::InvalidSortMode(part.into()).at(loc.clone()))?;
856-
857-
let type_char =
858-
type_char.trim().chars().next().ok_or_else(|| {
859-
ParseErrorKind::InvalidSortMode(part.into()).at(loc.clone())
860-
})?;
861-
862-
T::from_char(type_char)
863-
.map(|t| Column {
864-
name: name.trim().to_string(),
865-
r#type: t,
866-
})
867-
.ok_or_else(|| ParseErrorKind::InvalidType(type_char).at(loc.clone()))
864+
.map(|part| match part.split_once(':') {
865+
Some((name, type_str)) => parse_type_char(type_str, name.trim().to_string(), loc),
866+
None => parse_type_char(part, "?".into(), loc),
868867
})
869868
.try_collect()
870869
} else {
871-
// Original "III" format
872-
type_str
870+
col_str
873871
.chars()
874872
.map(|ch| {
875873
T::from_char(ch)
876-
.map(|t| Column {
877-
name: "?".into(),
878-
r#type: t,
879-
})
874+
.map(|t| Column::anon(t))
880875
.ok_or_else(|| ParseErrorKind::InvalidType(ch).at(loc.clone()))
881876
})
882877
.try_collect()
@@ -992,6 +987,20 @@ mod tests {
992987
use super::*;
993988
use crate::DefaultColumnType;
994989

990+
#[test]
991+
fn test_mixed_col_name_and_types() {
992+
let script = r#"
993+
query NAME:I,B
994+
select * from t
995+
----
996+
1 true
997+
998+
"#;
999+
1000+
let result = parse::<CustomColumnType>(script);
1001+
assert!(result.is_ok())
1002+
}
1003+
9951004
#[test]
9961005
fn test_trailing_comment() {
9971006
let script = "\

sqllogictest/src/runner.rs

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,19 @@ pub enum RecordOutput<T: ColumnType> {
4646
},
4747
}
4848

49-
#[derive(Debug, Clone, PartialEq)]
49+
#[derive(Debug, Clone)]
5050
pub struct Column<T: ColumnType> {
5151
pub name: String,
5252
pub r#type: T,
5353
}
5454

55+
impl<T: ColumnType> PartialEq for Column<T> {
56+
fn eq(&self, other: &Self) -> bool {
57+
self.r#type == other.r#type
58+
&& (other.name == "?" || self.name.to_lowercase() == other.name.to_lowercase())
59+
}
60+
}
61+
5562
impl<T: ColumnType + Display> Display for Column<T> {
5663
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
5764
write!(f, "{}:{}", self.name, self.r#type.to_char())
@@ -530,11 +537,19 @@ pub fn strict_column_validator<T: ColumnType>(
530537
actual: &Vec<Column<T>>,
531538
expected: &Vec<Column<T>>,
532539
) -> bool {
533-
actual.len() == expected.len()
534-
&& !actual
535-
.iter()
536-
.zip(expected.iter())
537-
.any(|(actual_column, expected_column)| actual_column != expected_column)
540+
if actual.len() != expected.len() {
541+
return false;
542+
}
543+
!actual
544+
.iter()
545+
.zip(expected.iter())
546+
.any(|(actual_col, expected_col)| {
547+
let type_mismatch = actual_col.r#type != expected_col.r#type;
548+
let name_mismatch = expected_col.name != "?"
549+
&& actual_col.name.to_lowercase() != expected_col.name.to_lowercase();
550+
551+
type_mismatch || name_mismatch
552+
})
538553
}
539554

540555
/// Sqllogictest runner.
@@ -978,11 +993,7 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
978993
sql,
979994
expected,
980995
},
981-
RecordOutput::Query {
982-
cols: types,
983-
rows,
984-
error,
985-
},
996+
RecordOutput::Query { cols, rows, error },
986997
) => {
987998
match (error, expected) {
988999
(None, QueryExpect::Error(_)) => {
@@ -1014,16 +1025,26 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
10141025
(
10151026
None,
10161027
QueryExpect::Results {
1017-
cols: expected_types,
1028+
cols: expected_cols,
10181029
results: expected_results,
10191030
..
10201031
},
10211032
) => {
1022-
if !(self.column_type_validator)(types, &expected_types) {
1033+
if !(self.column_type_validator)(cols, &expected_cols) {
10231034
return Err(TestErrorKind::QueryResultColumnsMismatch {
10241035
sql,
1025-
expected: expected_types.iter().map(|c| c.to_char()).join(""),
1026-
actual: types.iter().map(|c| c.to_char()).join(""),
1036+
expected: expected_cols
1037+
.iter()
1038+
.map(|Column { name, r#type }| {
1039+
format!("{}:{}", name, r#type.to_char())
1040+
})
1041+
.join(","),
1042+
actual: cols
1043+
.iter()
1044+
.map(|Column { name, r#type }| {
1045+
format!("{}:{}", name, r#type.to_char())
1046+
})
1047+
.join(","),
10271048
}
10281049
.at(loc));
10291050
}

0 commit comments

Comments
 (0)