1
+ import math
1
2
from typing import Dict , Sequence
2
3
import logging
3
4
13
14
ColType ,
14
15
UnknownColType ,
15
16
)
16
- from .base import MD5_HEXDIGITS , CHECKSUM_HEXDIGITS , BaseDialect , Database , import_helper , parse_table_name
17
+ from .base import MD5_HEXDIGITS , CHECKSUM_HEXDIGITS , BaseDialect , ThreadedDatabase , import_helper , parse_table_name
17
18
18
19
19
20
@import_helper (text = "You can install it using 'pip install databricks-sql-connector'" )
@@ -61,54 +62,57 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
61
62
return f"date_format({ value } , 'yyyy-MM-dd HH:mm:ss.{ precision_format } ')"
62
63
63
64
def normalize_number(self, value: str, coltype: NumericType) -> str:
    """Return a SQL expression rendering `value` as a normalized decimal string.

    The value is cast to decimal(38, p). When the column has fractional
    precision, format_number() pads the fraction to exactly p digits;
    format_number() also inserts thousands separators (commas), so they
    are stripped with replace() after stringification.
    """
    expr = f"cast({value} as decimal(38, {coltype.precision}))"
    if coltype.precision > 0:
        expr = f"format_number({expr}, {coltype.precision})"
    return f"replace({self.to_string(expr)}, ',', '')"
65
69
66
70
def _convert_db_precision_to_digits(self, p: int) -> int:
    """Convert a DB-reported precision to a decimal digit count.

    Subtracts 2 from the base conversion to compensate for weird precision
    issues observed with Databricks, clamped at zero so the result is never
    negative.
    """
    digits = super()._convert_db_precision_to_digits(p) - 2
    return max(digits, 0)
69
73
70
74
71
- class Databricks (Database ):
75
+ class Databricks (ThreadedDatabase ):
72
76
dialect = Dialect ()
73
77
74
- def __init__ (
75
- self ,
76
- http_path : str ,
77
- access_token : str ,
78
- server_hostname : str ,
79
- catalog : str = "hive_metastore" ,
80
- schema : str = "default" ,
81
- ** kwargs ,
82
- ):
83
- databricks = import_databricks ()
84
-
85
- self ._conn = databricks .sql .connect (
86
- server_hostname = server_hostname , http_path = http_path , access_token = access_token
87
- )
88
-
78
def __init__(self, *, thread_count, **kw):
    """Store connection arguments and start the thread pool.

    Connection kwargs (server_hostname, http_path, access_token, catalog,
    schema, ...) are kept in self._args; an actual connection is opened
    lazily per worker thread via create_connection().
    """
    # The databricks connector logs verbosely at INFO; keep only warnings.
    logging.getLogger("databricks.sql").setLevel(logging.WARNING)

    self._args = kw
    # NOTE(review): "hive_metastore" reads like a catalog name, not a schema
    # name — confirm the intended default schema with the caller.
    self.default_schema = kw.get("schema", "hive_metastore")
    super().__init__(thread_count=thread_count)
94
84
95
- def _query (self , sql_code : str ) -> list :
96
- "Uses the standard SQL cursor interface"
97
- return self ._query_conn (self ._conn , sql_code )
85
def create_connection(self):
    """Open and return a new databricks.sql connection from self._args.

    Raises ConnectionError (chained from the driver's error) when the
    connection cannot be established.
    """
    databricks = import_databricks()

    try:
        return databricks.sql.connect(
            server_hostname=self._args["server_hostname"],
            http_path=self._args["http_path"],
            access_token=self._args["access_token"],
            catalog=self._args["catalog"],
        )
    except databricks.sql.exc.Error as e:
        # Surface driver failures as a standard ConnectionError for callers.
        raise ConnectionError(*e.args) from e
98
97
99
98
def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
    """Fetch column metadata for `path` using the driver's cursor.columns().

    Databricks has INFORMATION_SCHEMA only for Databricks Runtime, not for
    Databricks SQL:
    https://docs.databricks.com/spark/latest/spark-sql/language-manual/information-schema/columns.html
    so column metadata is obtained through the cursor metadata API instead.

    Returns a dict mapping column name -> (name, type_name, decimal_digits,
    None, None). Raises RuntimeError when the table does not exist or has
    no columns.
    """
    schema, table = self._normalize_table_path(path)

    conn = self.create_connection()
    try:
        # Bug fix: previously only fetchall() was inside the finally-guarded
        # region, so a failure in cursor.columns() leaked the connection.
        with conn.cursor() as cursor:
            cursor.columns(catalog_name=self._args["catalog"], schema_name=schema, table_name=table)
            rows = cursor.fetchall()
    finally:
        conn.close()

    if not rows:
        raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")

    d = {r.COLUMN_NAME: (r.COLUMN_NAME, r.TYPE_NAME, r.DECIMAL_DIGITS, None, None) for r in rows}
    assert len(d) == len(rows)
    return d
114
118
@@ -120,27 +124,26 @@ def _process_table_schema(
120
124
121
125
resulted_rows = []
122
126
for row in rows :
123
- row_type = "DECIMAL" if row . DATA_TYPE == 3 else row . TYPE_NAME
124
- type_cls = self .TYPE_CLASSES .get (row_type , UnknownColType )
127
+ row_type = "DECIMAL" if row [ 1 ]. startswith ( "DECIMAL" ) else row [ 1 ]
128
+ type_cls = self .dialect . TYPE_CLASSES .get (row_type , UnknownColType )
125
129
126
130
if issubclass (type_cls , Integer ):
127
- row = (row . COLUMN_NAME , row_type , None , None , 0 )
131
+ row = (row [ 0 ] , row_type , None , None , 0 )
128
132
129
133
elif issubclass (type_cls , Float ):
130
- numeric_precision = self . _convert_db_precision_to_digits (row . DECIMAL_DIGITS )
131
- row = (row . COLUMN_NAME , row_type , None , numeric_precision , None )
134
+ numeric_precision = math . ceil (row [ 2 ] / math . log ( 2 , 10 ) )
135
+ row = (row [ 0 ] , row_type , None , numeric_precision , None )
132
136
133
137
elif issubclass (type_cls , Decimal ):
134
- # TYPE_NAME has a format DECIMAL(x,y)
135
- items = row .TYPE_NAME [8 :].rstrip (")" ).split ("," )
138
+ items = row [1 ][8 :].rstrip (")" ).split ("," )
136
139
numeric_precision , numeric_scale = int (items [0 ]), int (items [1 ])
137
- row = (row . COLUMN_NAME , row_type , None , numeric_precision , numeric_scale )
140
+ row = (row [ 0 ] , row_type , None , numeric_precision , numeric_scale )
138
141
139
142
elif issubclass (type_cls , Timestamp ):
140
- row = (row . COLUMN_NAME , row_type , row . DECIMAL_DIGITS , None , None )
143
+ row = (row [ 0 ] , row_type , row [ 2 ] , None , None )
141
144
142
145
else :
143
- row = (row . COLUMN_NAME , row_type , None , None , None )
146
+ row = (row [ 0 ] , row_type , None , None , None )
144
147
145
148
resulted_rows .append (row )
146
149
@@ -153,9 +156,6 @@ def parse_table_name(self, name: str) -> DbPath:
153
156
path = parse_table_name (name )
154
157
return self ._normalize_table_path (path )
155
158
156
- def close (self ):
157
- self ._conn .close ()
158
-
159
159
@property
def is_autocommit(self) -> bool:
    """Databricks connections auto-commit; there is no transaction to manage."""
    return True
0 commit comments