Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 396d5bd

Browse files
authored
Merge pull request #151 from datafold/fixes_jul5
Small Fixes
2 parents b83a72a + 6c7f05d commit 396d5bd

File tree

3 files changed

+16
-3
lines changed

3 files changed

+16
-3
lines changed

data_diff/databases/oracle.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .base import ThreadedDatabase, import_helper, ConnectError, QueryError
55
from .base import DEFAULT_DATETIME_PRECISION, DEFAULT_NUMERIC_PRECISION
66

7+
SESSION_TIME_ZONE = None # Changed by the tests
78

89
@import_helper("oracle")
910
def import_oracle():
@@ -34,7 +35,10 @@ def __init__(self, *, host, database, thread_count, **kw):
3435
def create_connection(self):
3536
self._oracle = import_oracle()
3637
try:
37-
return self._oracle.connect(**self.kwargs)
38+
c = self._oracle.connect(**self.kwargs)
39+
if SESSION_TIME_ZONE:
40+
c.cursor().execute(f"ALTER SESSION SET TIME_ZONE = '{SESSION_TIME_ZONE}'")
41+
return c
3842
except Exception as e:
3943
raise ConnectError(*e.args) from e
4044

data_diff/databases/postgresql.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from .base import ThreadedDatabase, import_helper, ConnectError
33
from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, _CHECKSUM_BITSIZE, TIMESTAMP_PRECISION_POS
44

5+
SESSION_TIME_ZONE = None # Changed by the tests
56

67
@import_helper("postgresql")
78
def import_postgresql():
@@ -47,13 +48,17 @@ def _convert_db_precision_to_digits(self, p: int) -> int:
4748
return super()._convert_db_precision_to_digits(p) - 2
4849

4950
def create_connection(self):
51+
if not self._args:
52+
self._args['host'] = None # psycopg2 requires 1+ arguments
53+
5054
pg = import_postgresql()
5155
try:
5256
c = pg.connect(**self._args)
53-
# c.cursor().execute("SET TIME ZONE 'UTC'")
57+
if SESSION_TIME_ZONE:
58+
c.cursor().execute(f"SET TIME ZONE '{SESSION_TIME_ZONE}'")
5459
return c
5560
except pg.OperationalError as e:
56-
raise ConnectError(*e._args) from e
61+
raise ConnectError(*e.args) from e
5762

5863
def quote(self, s: str):
5964
return f'"{s}"'

tests/test_database_types.py

+4
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from parameterized import parameterized
1313

1414
from data_diff import databases as db
15+
from data_diff.databases import postgresql, oracle
1516
from data_diff.utils import number_to_human
1617
from data_diff.diff_tables import TableDiffer, TableSegment, DEFAULT_BISECTION_THRESHOLD
1718
from .common import CONN_STRINGS, N_SAMPLES, N_THREADS, BENCHMARK, GIT_REVISION, random_table_suffix
@@ -20,6 +21,7 @@
2021
CONNS = {k: db.connect_to_uri(v, N_THREADS) for k, v in CONN_STRINGS.items()}
2122

2223
CONNS[db.MySQL].query("SET @@session.time_zone='+00:00'", None)
24+
oracle.SESSION_TIME_ZONE = postgresql.SESSION_TIME_ZONE = 'UTC'
2325

2426

2527
class PaginatedTable:
@@ -434,6 +436,8 @@ def _insert_to_table(conn, table, values, type):
434436
value = str(sample)
435437
elif isinstance(sample, datetime) and isinstance(conn, (db.Presto, db.Oracle)):
436438
value = f"timestamp '{sample}'"
439+
elif isinstance(sample, datetime) and isinstance(conn, db.BigQuery) and type == 'datetime':
440+
value = f"cast(timestamp '{sample}' as datetime)"
437441
elif isinstance(sample, bytearray):
438442
value = f"'{sample.decode()}'"
439443
else:

0 commit comments

Comments
 (0)