Skip to content

Commit a594375

Browse files
authored
Merge pull request #97 from benesch/update
Support UPDATE statements
2 parents 6fceba8 + 1cef68e commit a594375

File tree

3 files changed

+80
-3
lines changed

3 files changed

+80
-3
lines changed

src/sqlast/mod.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,7 @@ impl ToString for SQLStatement {
454454
} => {
455455
let mut s = format!("UPDATE {}", table_name.to_string());
456456
if !assignments.is_empty() {
457+
s += " SET ";
457458
s += &comma_separated_string(assignments);
458459
}
459460
if let Some(selection) = selection {
@@ -554,13 +555,13 @@ impl ToString for SQLObjectName {
554555
/// SQL assignment `foo = expr` as used in SQLUpdate
555556
#[derive(Debug, Clone, PartialEq, Hash)]
556557
pub struct SQLAssignment {
557-
id: SQLIdent,
558-
value: ASTNode,
558+
pub id: SQLIdent,
559+
pub value: ASTNode,
559560
}
560561

561562
impl ToString for SQLAssignment {
562563
fn to_string(&self) -> String {
563-
format!("SET {} = {}", self.id, self.value.to_string())
564+
format!("{} = {}", self.id, self.value.to_string())
564565
}
565566
}
566567

src/sqlparser.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ impl Parser {
117117
"DROP" => Ok(self.parse_drop()?),
118118
"DELETE" => Ok(self.parse_delete()?),
119119
"INSERT" => Ok(self.parse_insert()?),
120+
"UPDATE" => Ok(self.parse_update()?),
120121
"ALTER" => Ok(self.parse_alter()?),
121122
"COPY" => Ok(self.parse_copy()?),
122123
_ => parser_err!(format!(
@@ -1615,6 +1616,31 @@ impl Parser {
16151616
})
16161617
}
16171618

1619+
pub fn parse_update(&mut self) -> Result<SQLStatement, ParserError> {
1620+
let table_name = self.parse_object_name()?;
1621+
self.expect_keyword("SET")?;
1622+
let mut assignments = vec![];
1623+
loop {
1624+
let id = self.parse_identifier()?;
1625+
self.expect_token(&Token::Eq)?;
1626+
let value = self.parse_expr()?;
1627+
assignments.push(SQLAssignment { id, value });
1628+
if !self.consume_token(&Token::Comma) {
1629+
break;
1630+
}
1631+
}
1632+
let selection = if self.parse_keyword("WHERE") {
1633+
Some(self.parse_expr()?)
1634+
} else {
1635+
None
1636+
};
1637+
Ok(SQLStatement::SQLUpdate {
1638+
table_name,
1639+
assignments,
1640+
selection,
1641+
})
1642+
}
1643+
16181644
/// Parse a comma-delimited list of SQL expressions
16191645
pub fn parse_expr_list(&mut self) -> Result<Vec<ASTNode>, ParserError> {
16201646
let mut expr_list: Vec<ASTNode> = vec![];

tests/sqlparser_common.rs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,56 @@ fn parse_insert_invalid() {
8181
);
8282
}
8383

84+
#[test]
85+
fn parse_update() {
86+
let sql = "UPDATE t SET a = 1, b = 2, c = 3 WHERE d";
87+
match verified_stmt(sql) {
88+
SQLStatement::SQLUpdate {
89+
table_name,
90+
assignments,
91+
selection,
92+
..
93+
} => {
94+
assert_eq!(table_name.to_string(), "t".to_string());
95+
assert_eq!(
96+
assignments,
97+
vec![
98+
SQLAssignment {
99+
id: "a".into(),
100+
value: ASTNode::SQLValue(Value::Long(1)),
101+
},
102+
SQLAssignment {
103+
id: "b".into(),
104+
value: ASTNode::SQLValue(Value::Long(2)),
105+
},
106+
SQLAssignment {
107+
id: "c".into(),
108+
value: ASTNode::SQLValue(Value::Long(3)),
109+
},
110+
]
111+
);
112+
assert_eq!(selection.unwrap(), ASTNode::SQLIdentifier("d".into()));
113+
}
114+
_ => unreachable!(),
115+
}
116+
117+
verified_stmt("UPDATE t SET a = 1, a = 2, a = 3");
118+
119+
let sql = "UPDATE t WHERE 1";
120+
let res = parse_sql_statements(sql);
121+
assert_eq!(
122+
ParserError::ParserError("Expected SET, found: WHERE".to_string()),
123+
res.unwrap_err()
124+
);
125+
126+
let sql = "UPDATE t SET a = 1 extrabadstuff";
127+
let res = parse_sql_statements(sql);
128+
assert_eq!(
129+
ParserError::ParserError("Expected end of statement, found: extrabadstuff".to_string()),
130+
res.unwrap_err()
131+
);
132+
}
133+
84134
#[test]
85135
fn parse_invalid_table_name() {
86136
let ast = all_dialects().run_parser_method("db.public..customer", Parser::parse_object_name);

0 commit comments

Comments
 (0)