@@ -55,7 +55,9 @@ pub fn is_column_nullable(column: &str, table_name: &str, query: &str) -> Option
55
55
fn get_used_table_name < ' a > ( table_name : & str , from : & ' a ast:: FromClause ) -> Option < & ' a str > {
56
56
if let Some ( table) = & from. select {
57
57
match table. as_ref ( ) {
58
- ast:: SelectTable :: Table ( name, as_name, _) if name. name . 0 == table_name => {
58
+ ast:: SelectTable :: Table ( name, as_name, _)
59
+ if compare_identifier ( & name. name . 0 , table_name) =>
60
+ {
59
61
let used_table_name = match as_name {
60
62
Some ( ast:: As :: As ( name) ) => & name. 0 ,
61
63
Some ( ast:: As :: Elided ( name) ) => & name. 0 ,
@@ -80,7 +82,7 @@ fn get_used_table_name<'a>(table_name: &str, from: &'a ast::FromClause) -> Optio
80
82
| ast:: JoinOperator :: TypedJoin ( Some ( ast:: JoinType :: CROSS ) ) ,
81
83
table : ast:: SelectTable :: Table ( name, as_name, _) ,
82
84
..
83
- } if name. name . 0 == table_name => {
85
+ } if compare_identifier ( & name. name . 0 , table_name) => {
84
86
let used_table_name = match as_name {
85
87
Some ( ast:: As :: As ( name) ) => & name. 0 ,
86
88
Some ( ast:: As :: Elided ( name) ) => & name. 0 ,
@@ -190,10 +192,22 @@ fn test_expr(column_name: &str, table_name: &str, expr: &ast::Expr) -> Option<Nu
190
192
None
191
193
}
192
194
195
+ fn compare_identifier ( lhs : & str , rhs : & str ) -> bool {
196
+ lhs. trim_start_matches ( "\" " )
197
+ . trim_end_matches ( "\" " )
198
+ . to_lowercase ( )
199
+ == rhs
200
+ . trim_start_matches ( "\" " )
201
+ . trim_end_matches ( "\" " )
202
+ . to_lowercase ( )
203
+ }
204
+
193
205
fn expr_matches_name ( column_name : & str , table_name : & str , expr : & ast:: Expr ) -> bool {
194
206
match expr {
195
- ast:: Expr :: Id ( id) => id. 0 == column_name,
196
- ast:: Expr :: Qualified ( name, id) => name. 0 == table_name && id. 0 == column_name,
207
+ ast:: Expr :: Id ( id) => compare_identifier ( & id. 0 , column_name) ,
208
+ ast:: Expr :: Qualified ( name, id) => {
209
+ compare_identifier ( & name. 0 , table_name) && compare_identifier ( & id. 0 , column_name)
210
+ }
197
211
_ => false ,
198
212
}
199
213
}
@@ -388,4 +402,36 @@ mod tests {
388
402
Some ( NullableResult :: NotNull )
389
403
) ;
390
404
}
405
+
406
+ #[ test]
407
+ fn support_quoted_names ( ) {
408
+ assert_eq ! (
409
+ is_column_nullable( "id" , "foo" , "SELECT id FROM \" foo\" WHERE id = ?" ) ,
410
+ Some ( NullableResult :: NotNull )
411
+ ) ;
412
+ assert_eq ! (
413
+ is_column_nullable( "id" , "foo" , "SELECT id FROM \" foo\" WHERE \" foo\" .id = ?" ) ,
414
+ Some ( NullableResult :: NotNull )
415
+ ) ;
416
+ assert_eq ! (
417
+ is_column_nullable( "id" , "foo" , "SELECT id FROM foo WHERE \" id\" = ?" ) ,
418
+ Some ( NullableResult :: NotNull )
419
+ ) ;
420
+ }
421
+
422
+ #[ test]
423
+ fn support_uppercase_names ( ) {
424
+ assert_eq ! (
425
+ is_column_nullable( "id" , "foo" , "SELECT id FROM FOO WHERE id = ?" ) ,
426
+ Some ( NullableResult :: NotNull )
427
+ ) ;
428
+ assert_eq ! (
429
+ is_column_nullable( "id" , "bar" , "SELECT id FROM foo, BAR WHERE bar.id = ?" ) ,
430
+ Some ( NullableResult :: NotNull )
431
+ ) ;
432
+ assert_eq ! (
433
+ is_column_nullable( "id" , "bar" , "SELECT id FROM foo, bar WHERE ID = ?" ) ,
434
+ Some ( NullableResult :: NotNull )
435
+ ) ;
436
+ }
391
437
}
0 commit comments