Skip to content

Commit b41e3e2

Browse files
authored
Merge pull request #267 from ydb-platform/sqlalchemy-fix
sqlalchemy: validate identifiers to prevent injection
2 parents 5bad7e2 + 50de0d3 commit b41e3e2

File tree

2 files changed

+72
-6
lines changed

2 files changed

+72
-6
lines changed

ydb/_dbapi/cursor.py

+17-6
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import logging
44
import uuid
55
import decimal
6+
import string
67

78
import ydb
89
from .errors import DatabaseError, ProgrammingError
@@ -11,6 +12,17 @@
1112
logger = logging.getLogger(__name__)
1213

1314

15+
identifier_starts = {x for x in itertools.chain(string.ascii_letters, "_")}
16+
valid_identifier_chars = {x for x in itertools.chain(identifier_starts, string.digits)}
17+
18+
19+
def check_identifier_valid(idt: str):
20+
valid = idt and idt[0] in identifier_starts and all(c in valid_identifier_chars for c in idt)
21+
if not valid:
22+
raise ProgrammingError(f"Invalid identifier {idt}")
23+
return valid
24+
25+
1426
def get_column_type(type_obj):
1527
return str(ydb.convert.type_to_native(type_obj))
1628

@@ -48,7 +60,7 @@ def _generate_type_str(value):
4860
stype = f"Set<{nested_type}>"
4961

5062
if stype is None:
51-
raise ProgrammingError("Cannot translate python type to ydb type.", tvalue, value)
63+
raise ProgrammingError(f"Cannot translate value {value} (type {tvalue}) to ydb type.")
5264

5365
return stype
5466

@@ -70,6 +82,8 @@ def execute(self, sql, parameters=None, context=None):
7082
sql_params = None
7183

7284
if parameters:
85+
for name in parameters.keys():
86+
check_identifier_valid(name)
7387
sql = sql % {k: f"${k}" for k, v in parameters.items()}
7488
sql_params = {f"${k}": v for k, v in parameters.items()}
7589
declare_stms = _generate_declare_stms(sql_params)
@@ -137,13 +151,10 @@ def executescript(self, script):
137151
return self.execute(script)
138152

139153
def fetchone(self):
140-
if self.rows is None:
141-
return None
142-
return next(self.rows, None)
154+
return next(self.rows or [], None)
143155

144156
def fetchmany(self, size=None):
145-
size = self.arraysize if size is None else size
146-
return list(itertools.islice(self.rows, size))
157+
return list(itertools.islice(self.rows, size or self.arraysize))
147158

148159
def fetchall(self):
149160
return list(self.rows)

ydb/_dbapi/test_cursor.py

+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import pytest
2+
import uuid
3+
import decimal
4+
from datetime import date, datetime, timedelta
5+
6+
from .cursor import _generate_type_str, check_identifier_valid, ProgrammingError
7+
8+
9+
def test_check_identifier_valid():
10+
assert check_identifier_valid("id")
11+
assert check_identifier_valid("_id")
12+
assert check_identifier_valid("id0")
13+
assert check_identifier_valid("foo_bar")
14+
assert check_identifier_valid("foo_bar_1")
15+
16+
with pytest.raises(ProgrammingError):
17+
check_identifier_valid("")
18+
19+
with pytest.raises(ProgrammingError):
20+
check_identifier_valid("01")
21+
22+
with pytest.raises(ProgrammingError):
23+
check_identifier_valid("(a)")
24+
25+
with pytest.raises(ProgrammingError):
26+
check_identifier_valid("drop table")
27+
28+
29+
def test_generate_type_str():
30+
assert _generate_type_str(True) == "Bool"
31+
assert _generate_type_str(1) == "Int64"
32+
assert _generate_type_str("foo") == "Utf8"
33+
assert _generate_type_str(b"foo") == "String"
34+
assert _generate_type_str(3.1415) == "Double"
35+
assert _generate_type_str(uuid.uuid4()) == "Uuid"
36+
assert _generate_type_str(decimal.Decimal("3.1415926535")) == "Decimal(22, 9)"
37+
38+
assert _generate_type_str([1, 2, 3]) == "List<Int64>"
39+
assert _generate_type_str((1, "2", False)) == "Tuple<Int64, Utf8, Bool>"
40+
assert _generate_type_str({1, 2, 3}) == "Set<Int64>"
41+
assert _generate_type_str({"foo": 1, "bar": 2, "kek": 3.14}) == "Struct<foo: Int64, bar: Int64, kek: Double>"
42+
43+
assert _generate_type_str([[1], [2], [3]]) == "List<List<Int64>>"
44+
assert _generate_type_str([{"a": 1, "b": 2}, {"a": 11, "b": 22}]) == "List<Struct<a: Int64, b: Int64>>"
45+
assert _generate_type_str(("foo", [1], 3.14)) == "Tuple<Utf8, List<Int64>, Double>"
46+
47+
assert _generate_type_str(datetime.now()) == "Timestamp"
48+
assert _generate_type_str(date.today()) == "Date"
49+
assert _generate_type_str(timedelta(days=2)) == "Interval"
50+
51+
with pytest.raises(ProgrammingError):
52+
assert _generate_type_str(None)
53+
54+
with pytest.raises(ProgrammingError):
55+
assert _generate_type_str(object())

0 commit comments

Comments
 (0)