Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 92f5ed4

Browse files
committed
Cleanup database params - just pass the dict
1 parent d0ba8a5 commit 92f5ed4

File tree

7 files changed

+36
-44
lines changed

7 files changed

+36
-44
lines changed

data_diff/databases/connect.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,22 @@ def connect_to_uri(db_uri: str, thread_count: Optional[int] = 1) -> Database:
116116
kw = matcher.match_path(dsn)
117117

118118
if scheme == "bigquery":
119-
return cls(dsn.host, **kw)
119+
kw["project"] = dsn.host
120+
return cls(**kw)
121+
122+
if scheme == "snowflake":
123+
kw["account"] = dsn.host
124+
assert not dsn.port
125+
kw["user"] = dsn.user
126+
kw["password"] = dsn.password
127+
else:
128+
kw["host"] = dsn.host
129+
kw["port"] = dsn.port
130+
kw["user"] = dsn.user
131+
kw["password"] = dsn.password
132+
kw = {k: v for k, v in kw.items() if v is not None}
120133

121134
if issubclass(cls, ThreadedDatabase):
122-
return cls(dsn.host, dsn.port, dsn.user, dsn.password, thread_count=thread_count, **kw)
135+
return cls(thread_count=thread_count, **kw)
123136

124-
return cls(dsn.host, dsn.port, dsn.user, dsn.password, **kw)
137+
return cls(**kw)

data_diff/databases/database_types.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@ class ColType:
1818
pass
1919

2020

21+
class IKey(ABC):
22+
"Interface for ColType, for using a column as a key in data-diff"
23+
python_type: type
24+
25+
2126
@dataclass
2227
class PrecisionType(ColType):
2328
precision: int
@@ -54,7 +59,7 @@ class Float(FractionalType):
5459
pass
5560

5661

57-
class Decimal(FractionalType):
62+
class Decimal(FractionalType, IKey):
5863
@property
5964
def python_type(self) -> type:
6065
if self.precision == 0:
@@ -66,11 +71,6 @@ class StringType(ColType):
6671
pass
6772

6873

69-
class IKey(ABC):
70-
"Interface for ColType, for using a column as a key in data-diff"
71-
python_type: type
72-
73-
7474
class ColType_UUID(StringType, IKey):
7575
python_type = ArithUUID
7676

data_diff/databases/mysql.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,13 @@ class MySQL(ThreadedDatabase):
2929
}
3030
ROUNDS_ON_PREC_LOSS = True
3131

32-
def __init__(self, host, port, user, password, *, database, thread_count, **kw):
33-
args = dict(host=host, port=port, database=database, user=user, password=password, **kw)
34-
self._args = {k: v for k, v in args.items() if v is not None}
32+
def __init__(self, *, thread_count, **kw):
33+
self._args = kw
3534

3635
super().__init__(thread_count=thread_count)
3736

3837
# In MySQL schema and database are synonymous
39-
self.default_schema = database
38+
self.default_schema = kw["database"]
4039

4140
def create_connection(self):
4241
mysql = import_mysql()
@@ -48,7 +47,7 @@ def create_connection(self):
4847
elif e.errno == mysql.errorcode.ER_BAD_DB_ERROR:
4948
raise ConnectError("Database does not exist") from e
5049
else:
51-
raise ConnectError(*e.args) from e
50+
raise ConnectError(*e._args) from e
5251

5352
def quote(self, s: str):
5453
return f"`{s}`"

data_diff/databases/oracle.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,10 @@ class Oracle(ThreadedDatabase):
2424
}
2525
ROUNDS_ON_PREC_LOSS = True
2626

27-
def __init__(self, host, port, user, password, *, database, thread_count, **kw):
28-
assert not port
29-
self.kwargs = dict(user=user, password=password, dsn="%s/%s" % (host, database), **kw)
27+
def __init__(self, *, host, database, thread_count, **kw):
28+
self.kwargs = dict(dsn="%s/%s" % (host, database), **kw)
3029

31-
self.default_schema = user
30+
self.default_schema = kw.get('user')
3231

3332
super().__init__(thread_count=thread_count)
3433

data_diff/databases/postgresql.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ class PostgreSQL(ThreadedDatabase):
3535

3636
default_schema = "public"
3737

38-
def __init__(self, host, port, user, password, *, database, thread_count, **kw):
39-
self.args = dict(host=host, port=port, database=database, user=user, password=password, **kw)
38+
def __init__(self, *, thread_count, **kw):
39+
self._args = kw
4040

4141
super().__init__(thread_count=thread_count)
4242

@@ -47,11 +47,11 @@ def _convert_db_precision_to_digits(self, p: int) -> int:
4747
def create_connection(self):
4848
pg = import_postgresql()
4949
try:
50-
c = pg.connect(**self.args)
50+
c = pg.connect(**self._args)
5151
# c.cursor().execute("SET TIME ZONE 'UTC'")
5252
return c
5353
except pg.OperationalError as e:
54-
raise ConnectError(*e.args) from e
54+
raise ConnectError(*e._args) from e
5555

5656
def quote(self, s: str):
5757
return f'"{s}"'

data_diff/databases/presto.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,10 @@ class Presto(Database):
3535
}
3636
ROUNDS_ON_PREC_LOSS = True
3737

38-
def __init__(self, host, port, user, password, *, catalog, schema=None, **kw):
38+
def __init__(self, **kw):
3939
prestodb = import_presto()
40-
self.args = dict(host=host, user=user, catalog=catalog, schema=schema, **kw)
4140

42-
self._conn = prestodb.dbapi.connect(**self.args)
41+
self._conn = prestodb.dbapi.connect(**kw)
4342

4443
def quote(self, s: str):
4544
return f'"{s}"'

data_diff/databases/snowflake.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,19 +25,7 @@ class Snowflake(Database):
2525
}
2626
ROUNDS_ON_PREC_LOSS = False
2727

28-
def __init__(
29-
self,
30-
account: str,
31-
_port: int,
32-
user: str,
33-
password: str,
34-
*,
35-
warehouse: str,
36-
schema: str,
37-
database: str,
38-
role: str = None,
39-
**kw,
40-
):
28+
def __init__(self, *, schema: str, **kw):
4129
snowflake = import_snowflake()
4230
logging.getLogger("snowflake.connector").setLevel(logging.WARNING)
4331

@@ -48,12 +36,6 @@ def __init__(
4836

4937
assert '"' not in schema, "Schema name should not contain quotes!"
5038
self._conn = snowflake.connector.connect(
51-
user=user,
52-
password=password,
53-
account=account,
54-
role=role,
55-
database=database,
56-
warehouse=warehouse,
5739
schema=f'"{schema}"',
5840
**kw,
5941
)

0 commit comments

Comments
 (0)