Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit e884eb2

Browse files
authored
Merge pull request #336 from datafold/fix_current_timestamp
Added tests for current_timestamp + fixes for some of the dbs
2 parents c6e08bd + 4ac3e9f commit e884eb2

File tree

10 files changed

+37
-2
lines changed

10 files changed

+37
-2
lines changed

data_diff/sqeleton/databases/base.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,8 @@ class BaseDialect(AbstractDialect):
124124
SUPPORTS_INDEXES = False
125125
TYPE_CLASSES: Dict[str, type] = {}
126126

127+
PLACEHOLDER_TABLE = None # Used for Oracle
128+
127129
def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None):
128130
if offset:
129131
raise NotImplementedError("No support for OFFSET in query")
@@ -302,7 +304,9 @@ def query(self, sql_ast: Union[Expr, Generator], res_type: type = list):
302304
return int(res)
303305
elif res_type is datetime:
304306
res = _one(_one(res))
305-
return res # XXX parse timestamp?
307+
if isinstance(res, str):
308+
res = datetime.fromisoformat(res[:23]) # TODO use a better parsing method
309+
return res
306310
elif res_type is tuple:
307311
assert len(res) == 1, (sql_code, res)
308312
return res[0]

data_diff/sqeleton/databases/clickhouse.py

+3
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,9 @@ def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]:
155155
def set_timezone_to_utc(self) -> str:
156156
raise NotImplementedError()
157157

158+
def current_timestamp(self) -> str:
159+
return "now()"
160+
158161

159162
class Clickhouse(ThreadedDatabase):
160163
dialect = Dialect()

data_diff/sqeleton/databases/duckdb.py

+2
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,8 @@ def parse_type(
111111
def set_timezone_to_utc(self) -> str:
112112
return "SET GLOBAL TimeZone='UTC'"
113113

114+
def current_timestamp(self) -> str:
115+
return "current_timestamp"
114116

115117
class DuckDB(Database):
116118
dialect = Dialect()

data_diff/sqeleton/databases/oracle.py

+4
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ class Dialect(BaseDialect, Mixin_Schema):
8686
"VARCHAR2": Text,
8787
}
8888
ROUNDS_ON_PREC_LOSS = True
89+
PLACEHOLDER_TABLE = "DUAL"
8990

9091
def quote(self, s: str):
9192
return f'"{s}"'
@@ -152,6 +153,9 @@ def parse_type(
152153
def set_timezone_to_utc(self) -> str:
153154
return "ALTER SESSION SET TIME_ZONE = 'UTC'"
154155

156+
def current_timestamp(self) -> str:
157+
return "LOCALTIMESTAMP"
158+
155159

156160
class Oracle(ThreadedDatabase):
157161
dialect = Dialect()

data_diff/sqeleton/databases/postgresql.py

+3
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ def _convert_db_precision_to_digits(self, p: int) -> int:
9090
def set_timezone_to_utc(self) -> str:
9191
return "SET TIME ZONE 'UTC'"
9292

93+
def current_timestamp(self) -> str:
94+
return "current_timestamp"
95+
9396

9497
class PostgreSQL(ThreadedDatabase):
9598
dialect = PostgresqlDialect()

data_diff/sqeleton/databases/presto.py

+2
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,8 @@ def parse_type(
137137
def set_timezone_to_utc(self) -> str:
138138
return "SET TIME ZONE '+00:00'"
139139

140+
def current_timestamp(self) -> str:
141+
return "current_timestamp"
140142

141143
class Presto(Database):
142144
dialect = Dialect()

data_diff/sqeleton/databases/vertica.py

+3
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,9 @@ def parse_type(
144144
def set_timezone_to_utc(self) -> str:
145145
return "SET TIME ZONE TO 'UTC'"
146146

147+
def current_timestamp(self) -> str:
148+
return "current_timestamp(6)"
149+
147150

148151
class Vertica(ThreadedDatabase):
149152
dialect = Dialect()

data_diff/sqeleton/queries/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
or_,
1919
leftjoin,
2020
rightjoin,
21+
current_timestamp
2122
)
2223
from .ast_classes import Expr, ExprNode, Select, Count, BinOp, Explain, In, Code, Column
2324
from .extras import Checksum, NormalizeAsString, ApplyFuncAndNormalizeAsString

data_diff/sqeleton/queries/ast_classes.py

+3
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,9 @@ def compile(self, parent_c: Compiler) -> str:
601601

602602
if self.table:
603603
select += " FROM " + c.compile(self.table)
604+
elif c.dialect.PLACEHOLDER_TABLE:
605+
select += f" FROM {c.dialect.PLACEHOLDER_TABLE}"
606+
604607

605608
if self.where_exprs:
606609
select += " WHERE " + " AND ".join(map(c.compile, self.where_exprs))

tests/sqeleton/test_database.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from typing import Callable, List
2+
from datetime import datetime
23
import unittest
34

45
from ..common import str_to_checksum, TEST_MYSQL_CONN_STRING
56
from ..common import str_to_checksum, test_each_database_in_list, DiffTestCase, get_conn, random_table_suffix
67

7-
from data_diff.sqeleton.queries import table
8+
from data_diff.sqeleton.queries import table, current_timestamp
89

910
from data_diff import databases as dbs
1011
from data_diff.databases import connect
@@ -63,3 +64,12 @@ def test_table_list(self):
6364

6465
db.query(tbl.drop())
6566
assert not db.query(q)
67+
68+
69+
@test_each_database
70+
class TestQueries(unittest.TestCase):
71+
72+
def test_current_timestamp(self):
73+
db = get_conn(self.db_cls)
74+
res = db.query(current_timestamp(), datetime)
75+
assert isinstance(res, datetime), (res, type(res))

0 commit comments

Comments
 (0)