Skip to content

Commit ca59e4c

Browse files
authored
Merge branch 'main' into add-server-codes
2 parents 7618be4 + e19ad3c commit ca59e4c

File tree

4 files changed

+55
-6
lines changed

4 files changed

+55
-6
lines changed

CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
## 2.13.4 ##
2+
* fixed sqlalchemy get_columns method with not null columns
3+
14
## 2.13.3 ##
25
* fixed use transaction object when commit with flag
36

tests/sqlalchemy/conftest.py

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import pytest
2+
import sqlalchemy as sa
3+
4+
from ydb.sqlalchemy import register_dialect
5+
6+
7+
@pytest.fixture
8+
def sa_engine(endpoint, database):
9+
register_dialect()
10+
engine = sa.create_engine(
11+
"yql:///ydb/",
12+
connect_args={"database": database, "endpoint": endpoint},
13+
)
14+
15+
yield engine
16+
engine.dispose()

tests/sqlalchemy/test_inspect.py

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import ydb
2+
3+
import sqlalchemy as sa
4+
5+
6+
def test_get_columns(driver_sync, sa_engine):
7+
session = ydb.retry_operation_sync(
8+
lambda: driver_sync.table_client.session().create()
9+
)
10+
session.execute_scheme(
11+
"CREATE TABLE test(id Int64 NOT NULL, value TEXT, num DECIMAL(22, 9), PRIMARY KEY (id))"
12+
)
13+
inspect = sa.inspect(sa_engine)
14+
columns = inspect.get_columns("test")
15+
for c in columns:
16+
c["type"] = type(c["type"])
17+
18+
assert columns == [
19+
{"name": "id", "type": sa.INTEGER, "nullable": False},
20+
{"name": "value", "type": sa.TEXT, "nullable": True},
21+
{"name": "num", "type": sa.DECIMAL, "nullable": True},
22+
]
23+
24+
session.execute_scheme("DROP TABLE test")

ydb/sqlalchemy/__init__.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -191,11 +191,16 @@ def visit_function(self, func, add_to_result_map=None, **kwargs):
191191
ydb.PrimitiveType.DyNumber: sa.TEXT,
192192
}
193193

194-
def _get_column_type(t):
195-
if isinstance(t.item, ydb.DecimalType):
196-
return sa.DECIMAL(precision=t.item.precision, scale=t.item.scale)
194+
def _get_column_info(t):
195+
nullable = False
196+
if isinstance(t, ydb.OptionalType):
197+
nullable = True
198+
t = t.item
197199

198-
return COLUMN_TYPES[t.item]
200+
if isinstance(t, ydb.DecimalType):
201+
return sa.DECIMAL(precision=t.precision, scale=t.scale), nullable
202+
203+
return COLUMN_TYPES[t], nullable
199204

200205
class YqlDialect(DefaultDialect):
201206
name = "yql"
@@ -250,11 +255,12 @@ def get_columns(self, connection, table_name, schema=None, **kw):
250255
columns = raw_conn.describe(qt)
251256
as_compatible = []
252257
for column in columns:
258+
col_type, nullable = _get_column_info(column.type)
253259
as_compatible.append(
254260
{
255261
"name": column.name,
256-
"type": _get_column_type(column.type),
257-
"nullable": True,
262+
"type": col_type,
263+
"nullable": nullable,
258264
}
259265
)
260266

0 commit comments

Comments
 (0)