Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 04c2917

Browse files
authored
Merge pull request #273 from pik94/fix-databricks
Fix databricks
2 parents e31df92 + 9c93229 commit 04c2917

File tree

1 file changed

+43
-43
lines changed

1 file changed

+43
-43
lines changed

data_diff/databases/databricks.py

Lines changed: 43 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import math
12
from typing import Dict, Sequence
23
import logging
34

@@ -13,7 +14,7 @@
1314
ColType,
1415
UnknownColType,
1516
)
16-
from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, BaseDialect, Database, import_helper, parse_table_name
17+
from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, BaseDialect, ThreadedDatabase, import_helper, parse_table_name
1718

1819

1920
@import_helper(text="You can install it using 'pip install databricks-sql-connector'")
@@ -61,54 +62,57 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
6162
return f"date_format({value}, 'yyyy-MM-dd HH:mm:ss.{precision_format}')"
6263

6364
def normalize_number(self, value: str, coltype: NumericType) -> str:
64-
return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))")
65+
value = f"cast({value} as decimal(38, {coltype.precision}))"
66+
if coltype.precision > 0:
67+
value = f"format_number({value}, {coltype.precision})"
68+
return f"replace({self.to_string(value)}, ',', '')"
6569

6670
def _convert_db_precision_to_digits(self, p: int) -> int:
67-
# Subtracting 1 due to wierd precision issues
68-
return max(super()._convert_db_precision_to_digits(p) - 1, 0)
71+
# Subtracting 2 due to wierd precision issues
72+
return max(super()._convert_db_precision_to_digits(p) - 2, 0)
6973

7074

71-
class Databricks(Database):
75+
class Databricks(ThreadedDatabase):
7276
dialect = Dialect()
7377

74-
def __init__(
75-
self,
76-
http_path: str,
77-
access_token: str,
78-
server_hostname: str,
79-
catalog: str = "hive_metastore",
80-
schema: str = "default",
81-
**kwargs,
82-
):
83-
databricks = import_databricks()
84-
85-
self._conn = databricks.sql.connect(
86-
server_hostname=server_hostname, http_path=http_path, access_token=access_token
87-
)
88-
78+
def __init__(self, *, thread_count, **kw):
8979
logging.getLogger("databricks.sql").setLevel(logging.WARNING)
9080

91-
self.catalog = catalog
92-
self.default_schema = schema
93-
self.kwargs = kwargs
81+
self._args = kw
82+
self.default_schema = kw.get("schema", "hive_metastore")
83+
super().__init__(thread_count=thread_count)
9484

95-
def _query(self, sql_code: str) -> list:
96-
"Uses the standard SQL cursor interface"
97-
return self._query_conn(self._conn, sql_code)
85+
def create_connection(self):
86+
databricks = import_databricks()
87+
88+
try:
89+
return databricks.sql.connect(
90+
server_hostname=self._args["server_hostname"],
91+
http_path=self._args["http_path"],
92+
access_token=self._args["access_token"],
93+
catalog=self._args["catalog"],
94+
)
95+
except databricks.sql.exc.Error as e:
96+
raise ConnectionError(*e.args) from e
9897

9998
def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
10099
# Databricks has INFORMATION_SCHEMA only for Databricks Runtime, not for Databricks SQL.
101100
# https://docs.databricks.com/spark/latest/spark-sql/language-manual/information-schema/columns.html
102101
# So, to obtain information about schema, we should use another approach.
103102

103+
conn = self.create_connection()
104+
104105
schema, table = self._normalize_table_path(path)
105-
with self._conn.cursor() as cursor:
106-
cursor.columns(catalog_name=self.catalog, schema_name=schema, table_name=table)
107-
rows = cursor.fetchall()
106+
with conn.cursor() as cursor:
107+
cursor.columns(catalog_name=self._args["catalog"], schema_name=schema, table_name=table)
108+
try:
109+
rows = cursor.fetchall()
110+
finally:
111+
conn.close()
108112
if not rows:
109113
raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")
110114

111-
d = {r.COLUMN_NAME: r for r in rows}
115+
d = {r.COLUMN_NAME: (r.COLUMN_NAME, r.TYPE_NAME, r.DECIMAL_DIGITS, None, None) for r in rows}
112116
assert len(d) == len(rows)
113117
return d
114118

@@ -120,27 +124,26 @@ def _process_table_schema(
120124

121125
resulted_rows = []
122126
for row in rows:
123-
row_type = "DECIMAL" if row.DATA_TYPE == 3 else row.TYPE_NAME
124-
type_cls = self.TYPE_CLASSES.get(row_type, UnknownColType)
127+
row_type = "DECIMAL" if row[1].startswith("DECIMAL") else row[1]
128+
type_cls = self.dialect.TYPE_CLASSES.get(row_type, UnknownColType)
125129

126130
if issubclass(type_cls, Integer):
127-
row = (row.COLUMN_NAME, row_type, None, None, 0)
131+
row = (row[0], row_type, None, None, 0)
128132

129133
elif issubclass(type_cls, Float):
130-
numeric_precision = self._convert_db_precision_to_digits(row.DECIMAL_DIGITS)
131-
row = (row.COLUMN_NAME, row_type, None, numeric_precision, None)
134+
numeric_precision = math.ceil(row[2] / math.log(2, 10))
135+
row = (row[0], row_type, None, numeric_precision, None)
132136

133137
elif issubclass(type_cls, Decimal):
134-
# TYPE_NAME has a format DECIMAL(x,y)
135-
items = row.TYPE_NAME[8:].rstrip(")").split(",")
138+
items = row[1][8:].rstrip(")").split(",")
136139
numeric_precision, numeric_scale = int(items[0]), int(items[1])
137-
row = (row.COLUMN_NAME, row_type, None, numeric_precision, numeric_scale)
140+
row = (row[0], row_type, None, numeric_precision, numeric_scale)
138141

139142
elif issubclass(type_cls, Timestamp):
140-
row = (row.COLUMN_NAME, row_type, row.DECIMAL_DIGITS, None, None)
143+
row = (row[0], row_type, row[2], None, None)
141144

142145
else:
143-
row = (row.COLUMN_NAME, row_type, None, None, None)
146+
row = (row[0], row_type, None, None, None)
144147

145148
resulted_rows.append(row)
146149

@@ -153,9 +156,6 @@ def parse_table_name(self, name: str) -> DbPath:
153156
path = parse_table_name(name)
154157
return self._normalize_table_path(path)
155158

156-
def close(self):
157-
self._conn.close()
158-
159159
@property
160160
def is_autocommit(self) -> bool:
161161
return True

0 commit comments

Comments
 (0)