Skip to content

Commit 8149282

Browse files
authored
Merge pull request #232 from ydb-platform/fix-sqlalchemy-nullable
Fix sqlalchemy nullable
2 parents dd65645 + dc90176 commit 8149282

File tree

3 files changed

+33
-6
lines changed

3 files changed

+33
-6
lines changed

examples/_sqlalchemy_example/example.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def run_example_core(engine):
196196
def main():
197197
parser = argparse.ArgumentParser(
198198
formatter_class=argparse.RawDescriptionHelpFormatter,
199-
description="""\033[92mYandex.Database examples _sqlalchemy usage.\x1b[0m\n""",
199+
description="""\033[92mYandex.Database examples sqlalchemy usage.\x1b[0m\n""",
200200
)
201201
parser.add_argument(
202202
"-d",

tests/_sqlalchemy/_test_inspect.py

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import ydb
2+
3+
import sqlalchemy as sa
4+
5+
6+
def test_get_columns(driver_sync, engine):
7+
session = ydb.retry_operation_sync(
8+
lambda: driver_sync.table_client.session().create()
9+
)
10+
session.execute_scheme(
11+
"CREATE TABLE test(id Int64 NOT NULL, value TEXT, num DECIMAL(22, 9), PRIMARY KEY (id))"
12+
)
13+
inspect = sa.inspect(engine)
14+
columns = inspect.get_columns("test")
15+
for c in columns:
16+
c["type"] = type(c["type"])
17+
18+
assert columns == [
19+
{"name": "id", "type": sa.INTEGER, "nullable": False},
20+
{"name": "value", "type": sa.TEXT, "nullable": True},
21+
{"name": "num", "type": sa.DECIMAL, "nullable": True},
22+
]
23+
24+
session.execute_scheme("DROP TABLE test")

ydb/_sqlalchemy/__init__.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -206,14 +206,16 @@ def upsert(table):
206206
}
207207

208208

209-
def _get_column_type(t):
209+
def _get_column_info(t):
210+
nullable = False
210211
if isinstance(t, ydb.OptionalType):
212+
nullable = True
211213
t = t.item
212214

213215
if isinstance(t, ydb.DecimalType):
214-
return sa.DECIMAL(precision=t.item.precision, scale=t.item.scale)
216+
return sa.DECIMAL(precision=t.precision, scale=t.scale), nullable
215217

216-
return COLUMN_TYPES[t]
218+
return COLUMN_TYPES[t], nullable
217219

218220

219221
class YqlDialect(DefaultDialect):
@@ -268,11 +270,12 @@ def get_columns(self, connection, table_name, schema=None, **kw):
268270
columns = raw_conn.describe(qt)
269271
as_compatible = []
270272
for column in columns:
273+
col_type, nullable = _get_column_info(column.type)
271274
as_compatible.append(
272275
{
273276
"name": column.name,
274-
"type": _get_column_type(column.type),
275-
"nullable": True,
277+
"type": col_type,
278+
"nullable": nullable,
276279
}
277280
)
278281

0 commit comments

Comments
 (0)