Skip to content

Commit b8ba188

Browse files
authored
Merge pull request #91 from benesch/from-values
Improve VALUES-related syntax parsing
2 parents 057518b + 14e07eb commit b8ba188

File tree

4 files changed

+95
-45
lines changed

4 files changed

+95
-45
lines changed

src/sqlast/mod.rs

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,27 @@ mod sql_operator;
2020
mod sqltype;
2121
mod value;
2222

23+
use std::ops::Deref;
24+
2325
pub use self::ddl::{AlterTableOperation, TableConstraint};
2426
pub use self::query::{
2527
Cte, Fetch, Join, JoinConstraint, JoinOperator, SQLOrderByExpr, SQLQuery, SQLSelect,
26-
SQLSelectItem, SQLSetExpr, SQLSetOperator, TableFactor,
28+
SQLSelectItem, SQLSetExpr, SQLSetOperator, SQLValues, TableFactor,
2729
};
2830
pub use self::sqltype::SQLType;
2931
pub use self::value::Value;
3032

3133
pub use self::sql_operator::SQLOperator;
3234

3335
/// Like `vec.join(", ")`, but for any types implementing ToString.
34-
fn comma_separated_string<T: ToString>(vec: &[T]) -> String {
35-
vec.iter()
36-
.map(T::to_string)
36+
fn comma_separated_string<I>(iter: I) -> String
37+
where
38+
I: IntoIterator,
39+
I::Item: Deref,
40+
<I::Item as Deref>::Target: ToString,
41+
{
42+
iter.into_iter()
43+
.map(|t| t.deref().to_string())
3744
.collect::<Vec<String>>()
3845
.join(", ")
3946
}
@@ -339,8 +346,8 @@ pub enum SQLStatement {
339346
table_name: SQLObjectName,
340347
/// COLUMNS
341348
columns: Vec<SQLIdent>,
342-
/// VALUES (vector of rows to insert)
343-
values: Vec<Vec<ASTNode>>,
349+
/// A SQL query that specifies what to insert
350+
source: Box<SQLQuery>,
344351
},
345352
SQLCopy {
346353
/// TABLE
@@ -406,22 +413,13 @@ impl ToString for SQLStatement {
406413
SQLStatement::SQLInsert {
407414
table_name,
408415
columns,
409-
values,
416+
source,
410417
} => {
411-
let mut s = format!("INSERT INTO {}", table_name.to_string());
418+
let mut s = format!("INSERT INTO {} ", table_name.to_string());
412419
if !columns.is_empty() {
413-
s += &format!(" ({})", columns.join(", "));
414-
}
415-
if !values.is_empty() {
416-
s += &format!(
417-
" VALUES({})",
418-
values
419-
.iter()
420-
.map(|row| comma_separated_string(row))
421-
.collect::<Vec<String>>()
422-
.join(", ")
423-
);
420+
s += &format!("({}) ", columns.join(", "));
424421
}
422+
s += &source.to_string();
425423
s
426424
}
427425
SQLStatement::SQLCopy {
@@ -523,7 +521,7 @@ impl ToString for SQLStatement {
523521
"DROP {}{} {}{}",
524522
object_type.to_string(),
525523
if *if_exists { " IF EXISTS" } else { "" },
526-
comma_separated_string(&names),
524+
comma_separated_string(names),
527525
if *cascade { " CASCADE" } else { "" },
528526
),
529527
}

src/sqlast/query.rs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,16 @@ pub enum SQLSetExpr {
5858
left: Box<SQLSetExpr>,
5959
right: Box<SQLSetExpr>,
6060
},
61-
// TODO: ANSI SQL supports `TABLE` and `VALUES` here.
61+
Values(SQLValues),
62+
// TODO: ANSI SQL supports `TABLE` here.
6263
}
6364

6465
impl ToString for SQLSetExpr {
6566
fn to_string(&self) -> String {
6667
match self {
6768
SQLSetExpr::Select(s) => s.to_string(),
6869
SQLSetExpr::Query(q) => format!("({})", q.to_string()),
70+
SQLSetExpr::Values(v) => v.to_string(),
6971
SQLSetExpr::SetOperation {
7072
left,
7173
right,
@@ -364,3 +366,16 @@ impl ToString for Fetch {
364366
}
365367
}
366368
}
369+
370+
#[derive(Debug, Clone, PartialEq)]
371+
pub struct SQLValues(pub Vec<Vec<ASTNode>>);
372+
373+
impl ToString for SQLValues {
374+
fn to_string(&self) -> String {
375+
let rows = self
376+
.0
377+
.iter()
378+
.map(|row| format!("({})", comma_separated_string(row)));
379+
format!("VALUES {}", comma_separated_string(rows))
380+
}
381+
}

src/sqlparser.rs

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1361,6 +1361,8 @@ impl Parser {
13611361
let subquery = self.parse_query()?;
13621362
self.expect_token(&Token::RParen)?;
13631363
SQLSetExpr::Query(Box::new(subquery))
1364+
} else if self.parse_keyword("VALUES") {
1365+
SQLSetExpr::Values(self.parse_values()?)
13641366
} else {
13651367
return self.expected("SELECT or a subquery in the query body", self.peek_token());
13661368
};
@@ -1572,14 +1574,11 @@ impl Parser {
15721574
self.expect_keyword("INTO")?;
15731575
let table_name = self.parse_object_name()?;
15741576
let columns = self.parse_parenthesized_column_list(Optional)?;
1575-
self.expect_keyword("VALUES")?;
1576-
self.expect_token(&Token::LParen)?;
1577-
let values = self.parse_expr_list()?;
1578-
self.expect_token(&Token::RParen)?;
1577+
let source = Box::new(self.parse_query()?);
15791578
Ok(SQLStatement::SQLInsert {
15801579
table_name,
15811580
columns,
1582-
values: vec![values],
1581+
source,
15831582
})
15841583
}
15851584

@@ -1697,6 +1696,20 @@ impl Parser {
16971696
quantity,
16981697
})
16991698
}
1699+
1700+
pub fn parse_values(&mut self) -> Result<SQLValues, ParserError> {
1701+
let mut values = vec![];
1702+
loop {
1703+
self.expect_token(&Token::LParen)?;
1704+
values.push(self.parse_expr_list()?);
1705+
self.expect_token(&Token::RParen)?;
1706+
match self.peek_token() {
1707+
Some(Token::Comma) => self.next_token(),
1708+
_ => break,
1709+
};
1710+
}
1711+
Ok(SQLValues(values))
1712+
}
17001713
}
17011714

17021715
impl SQLWord {

tests/sqlparser_common.rs

Lines changed: 43 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,44 +14,61 @@ use sqlparser::test_utils::{all_dialects, expr_from_projection, only};
1414

1515
#[test]
1616
fn parse_insert_values() {
17-
let sql = "INSERT INTO customer VALUES(1, 2, 3)";
18-
check_one(sql, "customer", vec![]);
17+
let row = vec![
18+
ASTNode::SQLValue(Value::Long(1)),
19+
ASTNode::SQLValue(Value::Long(2)),
20+
ASTNode::SQLValue(Value::Long(3)),
21+
];
22+
let rows1 = vec![row.clone()];
23+
let rows2 = vec![row.clone(), row];
1924

20-
let sql = "INSERT INTO public.customer VALUES(1, 2, 3)";
21-
check_one(sql, "public.customer", vec![]);
25+
let sql = "INSERT INTO customer VALUES (1, 2, 3)";
26+
check_one(sql, "customer", &[], &rows1);
2227

23-
let sql = "INSERT INTO db.public.customer VALUES(1, 2, 3)";
24-
check_one(sql, "db.public.customer", vec![]);
28+
let sql = "INSERT INTO customer VALUES (1, 2, 3), (1, 2, 3)";
29+
check_one(sql, "customer", &[], &rows2);
2530

26-
let sql = "INSERT INTO public.customer (id, name, active) VALUES(1, 2, 3)";
31+
let sql = "INSERT INTO public.customer VALUES (1, 2, 3)";
32+
check_one(sql, "public.customer", &[], &rows1);
33+
34+
let sql = "INSERT INTO db.public.customer VALUES (1, 2, 3)";
35+
check_one(sql, "db.public.customer", &[], &rows1);
36+
37+
let sql = "INSERT INTO public.customer (id, name, active) VALUES (1, 2, 3)";
2738
check_one(
2839
sql,
2940
"public.customer",
30-
vec!["id".to_string(), "name".to_string(), "active".to_string()],
41+
&["id".to_string(), "name".to_string(), "active".to_string()],
42+
&rows1,
3143
);
3244

33-
fn check_one(sql: &str, expected_table_name: &str, expected_columns: Vec<String>) {
45+
fn check_one(
46+
sql: &str,
47+
expected_table_name: &str,
48+
expected_columns: &[String],
49+
expected_rows: &[Vec<ASTNode>],
50+
) {
3451
match verified_stmt(sql) {
3552
SQLStatement::SQLInsert {
3653
table_name,
3754
columns,
38-
values,
55+
source,
3956
..
4057
} => {
4158
assert_eq!(table_name.to_string(), expected_table_name);
4259
assert_eq!(columns, expected_columns);
43-
assert_eq!(
44-
vec![vec![
45-
ASTNode::SQLValue(Value::Long(1)),
46-
ASTNode::SQLValue(Value::Long(2)),
47-
ASTNode::SQLValue(Value::Long(3))
48-
]],
49-
values
50-
);
60+
match &source.body {
61+
SQLSetExpr::Values(SQLValues(values)) => {
62+
assert_eq!(values.as_slice(), expected_rows)
63+
}
64+
_ => unreachable!(),
65+
}
5166
}
5267
_ => unreachable!(),
5368
}
5469
}
70+
71+
verified_stmt("INSERT INTO customer WITH foo AS (SELECT 1) SELECT * FROM foo UNION VALUES (1)");
5572
}
5673

5774
#[test]
@@ -1383,6 +1400,13 @@ fn parse_union() {
13831400
verified_stmt("SELECT foo FROM tab UNION SELECT bar FROM TAB");
13841401
}
13851402

1403+
#[test]
1404+
fn parse_values() {
1405+
verified_stmt("SELECT * FROM (VALUES (1), (2), (3))");
1406+
verified_stmt("SELECT * FROM (VALUES (1), (2), (3)), (VALUES (1, 2, 3))");
1407+
verified_stmt("SELECT * FROM (VALUES (1)) UNION VALUES (1)");
1408+
}
1409+
13861410
#[test]
13871411
fn parse_multiple_statements() {
13881412
fn test_with(sql1: &str, sql2_kw: &str, sql2_rest: &str) {
@@ -1416,7 +1440,7 @@ fn parse_multiple_statements() {
14161440
" cte AS (SELECT 1 AS s) SELECT bar",
14171441
);
14181442
test_with("DELETE FROM foo", "SELECT", " bar");
1419-
test_with("INSERT INTO foo VALUES(1)", "SELECT", " bar");
1443+
test_with("INSERT INTO foo VALUES (1)", "SELECT", " bar");
14201444
test_with("CREATE TABLE foo (baz int)", "SELECT", " bar");
14211445
// Make sure that empty statements do not cause an error:
14221446
let res = parse_sql_statements(";;");

0 commit comments

Comments
 (0)