Skip to content

Commit ab7370b

Browse files
committed
feat: Add support for parameterized queries with '?'
1 parent 641de7f commit ab7370b

File tree

4 files changed

+14
-12
lines changed

4 files changed

+14
-12
lines changed

mindsdb_sql_parser/ast/insert.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def to_tree(self, *args, level=0, **kwargs):
7979
def get_string(self, *args, **kwargs):
8080
if self.columns is not None:
8181
cols = ', '.join([i.name for i in self.columns])
82-
columns_str = f'({cols})'
82+
columns_str = f' ({cols})'
8383
else:
8484
columns_str = ''
8585

@@ -97,4 +97,4 @@ def get_string(self, *args, **kwargs):
9797
else:
9898
from_select_str = ''
9999

100-
return f'INSERT INTO {str(self.table)}{columns_str} {values_str}{from_select_str}'
100+
return f'INSERT INTO {str(self.table)}{columns_str} {values_str}{from_select_str}'

mindsdb_sql_parser/ast/select/parameter.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ def to_tree(self, *args, level=0, **kwargs):
1212
return indent(level) + f'Parameter(value={repr(self.value)}{alias_str})'
1313

1414
def get_string(self, *args, **kwargs):
15-
return str(self.value)
15+
if self.value == '?':
16+
return self.value
17+
return ':' + str(self.value)
1618

1719
def __repr__(self):
1820
return f'Parameter({repr(self.value)})'

tests/test_base_sql/test_insert.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
class TestInsert:
88

99
def test_insert(self):
10-
sql = "INSERT INTO tbl_name(a, c) VALUES (1, 3), (4, 5)"
10+
sql = "INSERT INTO tbl_name (a, c) VALUES (1, 3), (4, 5)"
1111

1212
ast = parse_sql(sql)
1313
expected_ast = Insert(
@@ -37,7 +37,7 @@ def test_insert_no_columns(self):
3737
assert ast.to_tree() == expected_ast.to_tree()
3838

3939
def test_insert_from_select(self):
40-
sql = "INSERT INTO tbl_name(a, c) SELECT b, d from table2"
40+
sql = "INSERT INTO tbl_name (a, c) SELECT b, d from table2"
4141

4242
ast = parse_sql(sql)
4343
expected_ast = Insert(
@@ -78,7 +78,7 @@ class TestInsertMDB:
7878
def test_insert_from_union(self):
7979
from textwrap import dedent
8080
sql = dedent("""
81-
INSERT INTO tbl_name(a, c) SELECT * from table1
81+
INSERT INTO tbl_name (a, c) SELECT * from table1
8282
UNION
8383
SELECT * from table2""")[1:]
8484

tests/test_base_sql/test_parameters.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def test_select_with_parameter_in_where(self):
1414
])
1515
)
1616
assert ast.to_tree() == expected_ast.to_tree()
17-
assert str(ast) == str(expected_ast)
17+
assert str(ast) == sql
1818

1919
def test_select_multiple_parameters(self):
2020
sql = "SELECT * FROM tbl WHERE col1 > ? AND col2 = ?"
@@ -34,10 +34,10 @@ def test_select_multiple_parameters(self):
3434
])
3535
)
3636
assert ast.to_tree() == expected_ast.to_tree()
37-
assert str(ast) == str(expected_ast)
37+
assert str(ast) == sql
3838

3939
def test_insert_with_parameters(self):
40-
sql = "INSERT INTO tbl_name(a, c) VALUES (?, ?)"
40+
sql = "INSERT INTO tbl_name (a, c) VALUES (?, ?)"
4141
ast = parse_sql(sql)
4242
expected_ast = Insert(
4343
table=Identifier('tbl_name'),
@@ -47,7 +47,7 @@ def test_insert_with_parameters(self):
4747
]
4848
)
4949
assert ast.to_tree() == expected_ast.to_tree()
50-
assert str(ast) == str(expected_ast)
50+
assert str(ast) == sql
5151

5252
def test_insert_with_multiple_parameter_rows(self):
5353
sql = "INSERT INTO tbl_name VALUES (?, ?), (?, ?)"
@@ -60,7 +60,7 @@ def test_insert_with_multiple_parameter_rows(self):
6060
]
6161
)
6262
assert ast.to_tree() == expected_ast.to_tree()
63-
assert str(ast) == str(expected_ast)
63+
assert str(ast) == sql
6464

6565
def test_select_parameter_as_target(self):
6666
sql = "SELECT ?"
@@ -69,4 +69,4 @@ def test_select_parameter_as_target(self):
6969
targets=[Parameter('?')]
7070
)
7171
assert ast.to_tree() == expected_ast.to_tree()
72-
assert str(ast) == str(expected_ast)
72+
assert str(ast) == sql

0 commit comments

Comments
 (0)