26
26
from data_diff .databases .mssql import MsSQL
27
27
28
28
29
- @attrs .define ( frozen = True )
29
+ @attrs .frozen
30
30
class MatchUriPath :
31
31
database_cls : Type [Database ]
32
32
@@ -98,13 +98,11 @@ class Connect:
98
98
"""Provides methods for connecting to a supported database using a URL or connection dict."""
99
99
100
100
database_by_scheme : Dict [str , Database ]
101
- match_uri_path : Dict [str , MatchUriPath ]
102
101
conn_cache : MutableMapping [Hashable , Database ]
103
102
104
103
def __init__ (self , database_by_scheme : Dict [str , Database ] = DATABASE_BY_SCHEME ):
105
104
super ().__init__ ()
106
105
self .database_by_scheme = database_by_scheme
107
- self .match_uri_path = {name : MatchUriPath (cls ) for name , cls in database_by_scheme .items ()}
108
106
self .conn_cache = weakref .WeakValueDictionary ()
109
107
110
108
def for_databases (self , * dbs ) -> Self :
@@ -157,12 +155,10 @@ def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1, **kwargs)
157
155
return self .connect_with_dict (conn_dict , thread_count , ** kwargs )
158
156
159
157
try :
160
- matcher = self .match_uri_path [scheme ]
158
+ cls = self .database_by_scheme [scheme ]
161
159
except KeyError :
162
160
raise NotImplementedError (f"Scheme '{ scheme } ' currently not supported" )
163
161
164
- cls = matcher .database_cls
165
-
166
162
if scheme == "databricks" :
167
163
assert not dsn .user
168
164
kw = {}
@@ -175,6 +171,7 @@ def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1, **kwargs)
175
171
kw ["filepath" ] = dsn .dbname
176
172
kw ["dbname" ] = dsn .user
177
173
else :
174
+ matcher = MatchUriPath (cls )
178
175
kw = matcher .match_path (dsn )
179
176
180
177
if scheme == "bigquery" :
@@ -198,7 +195,7 @@ def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1, **kwargs)
198
195
199
196
kw = {k : v for k , v in kw .items () if v is not None }
200
197
201
- if issubclass (cls , ThreadedDatabase ):
198
+ if isinstance ( cls , type ) and issubclass (cls , ThreadedDatabase ):
202
199
db = cls (thread_count = thread_count , ** kw , ** kwargs )
203
200
else :
204
201
db = cls (** kw , ** kwargs )
@@ -209,11 +206,10 @@ def connect_with_dict(self, d, thread_count, **kwargs):
209
206
d = dict (d )
210
207
driver = d .pop ("driver" )
211
208
try :
212
- matcher = self .match_uri_path [driver ]
209
+ cls = self .database_by_scheme [driver ]
213
210
except KeyError :
214
211
raise NotImplementedError (f"Driver '{ driver } ' currently not supported" )
215
212
216
- cls = matcher .database_cls
217
213
if issubclass (cls , ThreadedDatabase ):
218
214
db = cls (thread_count = thread_count , ** d , ** kwargs )
219
215
else :
0 commit comments