Skip to content

Commit 1b7266a

Browse files
committed
feat: check if column is nullable
1 parent adfaab4 commit 1b7266a

File tree

5 files changed

+400
-4
lines changed

5 files changed

+400
-4
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ jobs:
3636

3737
- run: cd parser && cargo fmt -- --check
3838

39+
- run: cd parser && cargo test
40+
3941
- uses: taiki-e/install-action@v2
4042
with:
4143

parser/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1+
pub mod nullable;
12
pub mod parameters;

parser/src/nullable.rs

Lines changed: 362 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,362 @@
1+
use fallible_iterator::FallibleIterator;
2+
use sqlite3_parser::ast;
3+
use sqlite3_parser::lexer::sql::Parser;
4+
use wasm_bindgen::prelude::*;
5+
6+
#[wasm_bindgen]
7+
#[derive(Clone, Debug, PartialEq, Eq)]
8+
pub enum NullableResult {
9+
NotNull,
10+
Null,
11+
}
12+
13+
#[wasm_bindgen]
14+
pub fn is_column_nullable(column: &str, table_name: &str, query: &str) -> Option<NullableResult> {
15+
let mut parser = Parser::new(query.as_bytes());
16+
let cmd = parser.next().ok()??;
17+
18+
if let ast::Cmd::Stmt(ast::Stmt::Select(select)) = cmd {
19+
if let ast::OneSelect::Select {
20+
from: Some(from),
21+
where_clause,
22+
..
23+
} = select.body.select
24+
{
25+
let used_table_name = get_used_table_name(table_name, &from)?;
26+
27+
if let Some(where_clause) = where_clause {
28+
let result = test_expr(column, used_table_name, &where_clause);
29+
if result.is_some() {
30+
return result;
31+
}
32+
}
33+
34+
if let Some(joins) = &from.joins {
35+
// https://www.sqlite.org/lang_select.html#special_handling_of_cross_join_
36+
return joins.iter().find_map(|join| match join {
37+
ast::JoinedSelectTable {
38+
operator:
39+
ast::JoinOperator::Comma
40+
| ast::JoinOperator::TypedJoin(None)
41+
| ast::JoinOperator::TypedJoin(Some(ast::JoinType::INNER))
42+
| ast::JoinOperator::TypedJoin(Some(ast::JoinType::CROSS)),
43+
constraint: Some(ast::JoinConstraint::On(expr)),
44+
..
45+
} => test_expr(column, used_table_name, expr),
46+
_ => None,
47+
});
48+
}
49+
}
50+
}
51+
52+
None
53+
}
54+
55+
fn get_used_table_name<'a>(table_name: &str, from: &'a ast::FromClause) -> Option<&'a str> {
56+
if let Some(table) = &from.select {
57+
match table.as_ref() {
58+
ast::SelectTable::Table(name, as_name, _) if name.name.0 == table_name => {
59+
let used_table_name = match as_name {
60+
Some(ast::As::As(name)) => &name.0,
61+
Some(ast::As::Elided(name)) => &name.0,
62+
None => &name.name.0,
63+
};
64+
65+
return Some(used_table_name);
66+
}
67+
_ => {}
68+
};
69+
}
70+
71+
if let Some(joins) = &from.joins {
72+
// https://www.sqlite.org/lang_select.html#special_handling_of_cross_join_
73+
for join in joins {
74+
match join {
75+
ast::JoinedSelectTable {
76+
operator:
77+
ast::JoinOperator::Comma
78+
| ast::JoinOperator::TypedJoin(None)
79+
| ast::JoinOperator::TypedJoin(Some(ast::JoinType::INNER))
80+
| ast::JoinOperator::TypedJoin(Some(ast::JoinType::CROSS)),
81+
table: ast::SelectTable::Table(name, as_name, _),
82+
..
83+
} if name.name.0 == table_name => {
84+
let used_table_name = match as_name {
85+
Some(ast::As::As(name)) => &name.0,
86+
Some(ast::As::Elided(name)) => &name.0,
87+
None => &name.name.0,
88+
};
89+
90+
return Some(used_table_name);
91+
}
92+
_ => {}
93+
}
94+
}
95+
}
96+
97+
None
98+
}
99+
100+
fn test_expr(column_name: &str, table_name: &str, expr: &ast::Expr) -> Option<NullableResult> {
101+
match expr {
102+
ast::Expr::Binary(left, ast::Operator::Equals, right)
103+
| ast::Expr::Binary(left, ast::Operator::NotEquals, right)
104+
| ast::Expr::Binary(left, ast::Operator::Greater, right)
105+
| ast::Expr::Binary(left, ast::Operator::GreaterEquals, right)
106+
| ast::Expr::Binary(left, ast::Operator::Less, right)
107+
| ast::Expr::Binary(left, ast::Operator::LessEquals, right)
108+
if expr_matches_name(column_name, table_name, left)
109+
|| expr_matches_name(column_name, table_name, right) =>
110+
{
111+
return Some(NullableResult::NotNull);
112+
}
113+
ast::Expr::InList { lhs, .. } if expr_matches_name(column_name, table_name, lhs) => {
114+
return Some(NullableResult::NotNull);
115+
}
116+
// column is null
117+
ast::Expr::Binary(left, ast::Operator::Is, right)
118+
if **right == ast::Expr::Literal(ast::Literal::Null)
119+
&& expr_matches_name(column_name, table_name, left) =>
120+
{
121+
return Some(NullableResult::Null);
122+
}
123+
// null is column
124+
ast::Expr::Binary(left, ast::Operator::Is, right)
125+
if **left == ast::Expr::Literal(ast::Literal::Null)
126+
&& expr_matches_name(column_name, table_name, right) =>
127+
{
128+
return Some(NullableResult::Null);
129+
}
130+
// column is not null
131+
ast::Expr::Binary(left, ast::Operator::IsNot, right)
132+
if **right == ast::Expr::Literal(ast::Literal::Null)
133+
&& expr_matches_name(column_name, table_name, left) =>
134+
{
135+
return Some(NullableResult::NotNull);
136+
}
137+
// null is not column
138+
ast::Expr::Binary(left, ast::Operator::IsNot, right)
139+
if **left == ast::Expr::Literal(ast::Literal::Null)
140+
&& expr_matches_name(column_name, table_name, right) =>
141+
{
142+
return Some(NullableResult::NotNull);
143+
}
144+
// column notnull
145+
// column not null
146+
ast::Expr::NotNull(expr) if expr_matches_name(column_name, table_name, expr) => {
147+
return Some(NullableResult::NotNull);
148+
}
149+
// expr and expr
150+
ast::Expr::Binary(left, ast::Operator::And, right) => {
151+
return test_expr(column_name, table_name, left)
152+
.or_else(|| test_expr(column_name, table_name, right));
153+
}
154+
// expr or expr
155+
ast::Expr::Binary(left, ast::Operator::Or, right) => {
156+
let left = test_expr(column_name, table_name, left);
157+
let right = test_expr(column_name, table_name, right);
158+
return match (left, right) {
159+
(Some(NullableResult::NotNull), Some(NullableResult::NotNull)) => {
160+
Some(NullableResult::NotNull)
161+
}
162+
(Some(NullableResult::Null), Some(NullableResult::Null)) => {
163+
Some(NullableResult::Null)
164+
}
165+
_ => None,
166+
};
167+
}
168+
// (expr)
169+
ast::Expr::Parenthesized(exprs) => {
170+
let mut iter = exprs
171+
.iter()
172+
.map(|expr| test_expr(column_name, table_name, expr));
173+
let first = iter.next()?;
174+
175+
return iter.all(|x| x == first).then_some(first)?;
176+
}
177+
_ => {
178+
// println!("Unmatched expr: {:?}", expr);
179+
}
180+
}
181+
None
182+
}
183+
184+
fn expr_matches_name(column_name: &str, table_name: &str, expr: &ast::Expr) -> bool {
185+
match expr {
186+
ast::Expr::Id(id) => id.0 == column_name,
187+
ast::Expr::Qualified(name, id) => name.0 == table_name && id.0 == column_name,
188+
_ => false,
189+
}
190+
}
191+
192+
#[cfg(test)]
193+
mod tests {
194+
use super::*;
195+
196+
#[test]
197+
fn returns_none_when_nothing_is_provable() {
198+
assert_eq!(is_column_nullable("id", "foo", "select * from foo"), None);
199+
}
200+
201+
#[test]
202+
fn support_not_null() {
203+
assert_eq!(
204+
is_column_nullable("id", "foo", "select * from foo where id is not null"),
205+
Some(NullableResult::NotNull)
206+
);
207+
assert_eq!(
208+
is_column_nullable("id", "foo", "select * from foo where id notnull"),
209+
Some(NullableResult::NotNull)
210+
);
211+
}
212+
213+
#[test]
214+
fn support_aliased_table() {
215+
assert_eq!(
216+
is_column_nullable("id", "foo", "select * from foo f where id notnull"),
217+
Some(NullableResult::NotNull)
218+
);
219+
assert_eq!(
220+
is_column_nullable("id", "foo", "select * from foo f where f.id notnull"),
221+
Some(NullableResult::NotNull)
222+
);
223+
assert_eq!(
224+
is_column_nullable("id", "foo", "select * from foo f where foo.id notnull"),
225+
None
226+
);
227+
}
228+
229+
#[test]
230+
fn support_aliased_table_using_as() {
231+
assert_eq!(
232+
is_column_nullable("id", "foo", "select * from foo as f where id notnull"),
233+
Some(NullableResult::NotNull)
234+
);
235+
assert_eq!(
236+
is_column_nullable("id", "foo", "select * from foo as f where f.id notnull"),
237+
Some(NullableResult::NotNull)
238+
);
239+
assert_eq!(
240+
is_column_nullable("id", "foo", "select * from foo as f where foo.id notnull"),
241+
None
242+
);
243+
}
244+
245+
#[test]
246+
fn support_and() {
247+
assert_eq!(
248+
is_column_nullable("id", "foo", "select * from foo where 1=1 and id not null",),
249+
Some(NullableResult::NotNull)
250+
);
251+
}
252+
253+
#[test]
254+
fn support_or() {
255+
assert_eq!(
256+
is_column_nullable(
257+
"id",
258+
"foo",
259+
"select * from foo where id is not null or id is not null and 1=1",
260+
),
261+
Some(NullableResult::NotNull)
262+
);
263+
}
264+
265+
#[test]
266+
fn support_parens() {
267+
assert_eq!(
268+
is_column_nullable("id", "foo", "select * from foo f where (id not null)",),
269+
Some(NullableResult::NotNull)
270+
);
271+
assert_eq!(
272+
is_column_nullable(
273+
"id",
274+
"foo",
275+
"select * from foo f where (id is null) or (id is null)",
276+
),
277+
Some(NullableResult::Null)
278+
);
279+
assert_eq!(
280+
is_column_nullable(
281+
"id",
282+
"foo",
283+
"select * from foo f where (id is null) and (id is null)",
284+
),
285+
Some(NullableResult::Null)
286+
);
287+
}
288+
289+
#[test]
290+
fn support_is_null() {
291+
assert_eq!(
292+
is_column_nullable("id", "foo", "select * from foo f where id is null",),
293+
Some(NullableResult::Null)
294+
);
295+
}
296+
297+
#[test]
298+
fn support_join() {
299+
assert_eq!(
300+
is_column_nullable(
301+
"id",
302+
"bar",
303+
"select * from foo join bar where bar.id is null"
304+
),
305+
Some(NullableResult::Null)
306+
);
307+
assert_eq!(
308+
is_column_nullable(
309+
"id",
310+
"bar",
311+
"select * from foo join bar b where b.id is null"
312+
),
313+
Some(NullableResult::Null)
314+
);
315+
assert_eq!(
316+
is_column_nullable("id", "bar", "select * from foo, bar b where b.id is null"),
317+
Some(NullableResult::Null)
318+
);
319+
}
320+
321+
#[test]
322+
fn support_yoda_null_check() {
323+
assert_eq!(
324+
is_column_nullable("id", "foo", "select * from foo f where null is id",),
325+
Some(NullableResult::Null)
326+
);
327+
328+
assert_eq!(
329+
is_column_nullable("id", "foo", "select * from foo f where null is not id",),
330+
Some(NullableResult::NotNull)
331+
);
332+
}
333+
334+
#[test]
335+
fn support_in_list() {
336+
assert_eq!(
337+
is_column_nullable("id", "foo", "select * from foo f where id in (:bar)"),
338+
Some(NullableResult::NotNull)
339+
);
340+
}
341+
342+
#[test]
343+
fn support_constraints_on_join() {
344+
assert_eq!(
345+
is_column_nullable(
346+
"id",
347+
"foo",
348+
"SELECT foo.id FROM foo INNER JOIN bar ON bar.id = foo.id"
349+
),
350+
Some(NullableResult::NotNull)
351+
);
352+
353+
assert_eq!(
354+
is_column_nullable(
355+
"id",
356+
"bar",
357+
"SELECT foo.id FROM foo INNER JOIN bar ON bar.id = foo.id"
358+
),
359+
Some(NullableResult::NotNull)
360+
);
361+
}
362+
}

0 commit comments

Comments
 (0)