Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 7104319

Browse files
committed
Add support for BigQuery service-account authentication
1 parent b39aa17 commit 7104319

File tree

4 files changed

+75
-22
lines changed

4 files changed

+75
-22
lines changed

data_diff/dbt_parser.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from dbt.config.renderer import ProfileRenderer
1212

1313
from data_diff.errors import (
14-
DataDiffDbtBigQueryOauthOnlyError,
14+
DataDiffDbtBigQueryUnsupportedMethodError,
1515
DataDiffDbtConnectionNotImplementedError,
1616
DataDiffDbtCoreNoRunnerError,
1717
DataDiffDbtNoSuccessfulModelsInRunError,
@@ -204,7 +204,9 @@ def get_run_results_models(self):
204204
success_models = [x.unique_id for x in run_results_obj.results if x.status.name == "success"]
205205
models = [self.manifest_obj.nodes.get(x) for x in success_models]
206206
if not models:
207-
raise DataDiffDbtNoSuccessfulModelsInRunError("Expected > 0 successful models runs from the last dbt command.")
207+
raise DataDiffDbtNoSuccessfulModelsInRunError(
208+
"Expected > 0 successful models runs from the last dbt command."
209+
)
208210

209211
return models
210212

@@ -295,17 +297,34 @@ def set_connection(self):
295297
else:
296298
raise DataDiffDbtSnowflakeSetConnectionError("Snowflake: unsupported auth method")
297299
elif conn_type == "bigquery":
300+
supported_methods = {
301+
"oauth": {
302+
"conn_info": {
303+
"driver": conn_type,
304+
"project": credentials.get("project"),
305+
"dataset": credentials.get("dataset"),
306+
}
307+
},
308+
"service-account": {
309+
"conn_info": {
310+
"driver": conn_type,
311+
"project": credentials.get("project"),
312+
"dataset": credentials.get("dataset"),
313+
"keyfile": credentials.get("keyfile"),
314+
}
315+
},
316+
}
298317
method = credentials.get("method")
299318
# there are many connection types https://docs.getdbt.com/reference/warehouse-setups/bigquery-setup#oauth-via-gcloud
300319
# the "oauth" method assumes that the user is auth'd via `gcloud auth application-default login`; "service-account" uses the keyfile from the profile instead
301-
if method is None or method != "oauth":
302-
raise DataDiffDbtBigQueryOauthOnlyError("Oauth is the current method supported for Big Query.")
303-
conn_info = {
304-
"driver": conn_type,
305-
"project": credentials.get("project"),
306-
"dataset": credentials.get("dataset"),
307-
}
320+
if method not in supported_methods:
321+
raise DataDiffDbtBigQueryUnsupportedMethodError(
322+
f"Method: {method} is not in the current methods supported for Big Query ({supported_methods.keys()})."
323+
)
324+
325+
conn_info = supported_methods[method]["conn_info"]
308326
self.threads = credentials.get("threads")
327+
309328
elif conn_type == "duckdb":
310329
conn_info = {
311330
"driver": conn_type,
@@ -315,7 +334,9 @@ def set_connection(self):
315334
if (credentials.get("pass") is None and credentials.get("password") is None) or credentials.get(
316335
"method"
317336
) == "iam":
318-
raise DataDiffDbtRedshiftPasswordOnlyError("Only password authentication is currently supported for Redshift.")
337+
raise DataDiffDbtRedshiftPasswordOnlyError(
338+
"Only password authentication is currently supported for Redshift."
339+
)
319340
conn_info = {
320341
"driver": conn_type,
321342
"host": credentials.get("host"),

data_diff/errors.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ class DataDiffDbtSnowflakeSetConnectionError(Exception):
2626
"Raised when a dbt snowflake profile has unexpected values."
2727

2828

29-
class DataDiffDbtBigQueryOauthOnlyError(Exception):
30-
"Raised when trying to use a method other than oauth with BigQuery."
29+
class DataDiffDbtBigQueryUnsupportedMethodError(Exception):
30+
"Raised when a dbt BigQuery profile uses an authentication method that is not supported."
3131

3232

3333
class DataDiffDbtRedshiftPasswordOnlyError(Exception):

data_diff/sqeleton/databases/bigquery.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,18 @@ def import_bigquery():
3636
return bigquery
3737

3838

39+
def import_bigquery_service_account():
40+
from google.oauth2 import service_account
41+
42+
return service_account
43+
44+
3945
class Mixin_MD5(AbstractMixin_MD5):
4046
def md5_as_int(self, s: str) -> str:
4147
return f"cast(cast( ('0x' || substr(TO_HEX(md5({s})), 18)) as int64) as numeric)"
4248

4349

4450
class Mixin_NormalizeValue(AbstractMixin_NormalizeValue):
45-
4651
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
4752
if coltype.rounds:
4853
timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))"
@@ -144,8 +149,8 @@ class Dialect(BaseDialect, Mixin_Schema):
144149
"BOOL": Boolean,
145150
"JSON": JSON,
146151
}
147-
TYPE_ARRAY_RE = re.compile(r'ARRAY<(.+)>')
148-
TYPE_STRUCT_RE = re.compile(r'STRUCT<(.+)>')
152+
TYPE_ARRAY_RE = re.compile(r"ARRAY<(.+)>")
153+
TYPE_STRUCT_RE = re.compile(r"STRUCT<(.+)>")
149154
MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_TimeTravel, Mixin_RandomSample}
150155

151156
def random(self) -> str:
@@ -173,7 +178,6 @@ def parse_type(
173178
) -> ColType:
174179
col_type = super().parse_type(table_path, col_name, type_repr, *args, **kwargs)
175180
if isinstance(col_type, UnknownColType):
176-
177181
m = self.TYPE_ARRAY_RE.fullmatch(type_repr)
178182
if m:
179183
item_type = self.parse_type(table_path, col_name, m.group(1), *args, **kwargs)
@@ -207,9 +211,18 @@ class BigQuery(Database):
207211
dialect = Dialect()
208212

209213
def __init__(self, project, *, dataset, **kw):
214+
credentials = None
210215
bigquery = import_bigquery()
211216

212-
self._client = bigquery.Client(project, **kw)
217+
keyfile = kw.pop("keyfile", None)
218+
if keyfile:
219+
bigquery_service_account = import_bigquery_service_account()
220+
credentials = bigquery_service_account.Credentials.from_service_account_file(
221+
keyfile,
222+
scopes=["https://www.googleapis.com/auth/cloud-platform"],
223+
)
224+
225+
self._client = bigquery.Client(project=project, credentials=credentials, **kw)
213226
self.project = project
214227
self.dataset = dataset
215228

tests/test_dbt.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from data_diff.cloud.datafold_api import TCloudApiOrgMeta
66
from data_diff.diff_tables import Algorithm
77
from data_diff.errors import (
8-
DataDiffDbtBigQueryOauthOnlyError,
8+
DataDiffDbtBigQueryUnsupportedMethodError,
99
DataDiffDbtConnectionNotImplementedError,
1010
DataDiffDbtCoreNoRunnerError,
1111
DataDiffDbtNoSuccessfulModelsInRunError,
@@ -276,7 +276,7 @@ def test_set_connection_snowflake_key_and_password(self):
276276

277277
self.assertNotIsInstance(mock_self.connection, dict)
278278

279-
def test_set_connection_bigquery_success(self):
279+
def test_set_connection_bigquery_oauth(self):
280280
expected_driver = "bigquery"
281281
expected_credentials = {
282282
"method": "oauth",
@@ -293,17 +293,36 @@ def test_set_connection_bigquery_success(self):
293293
self.assertEqual(mock_self.connection.get("project"), expected_credentials["project"])
294294
self.assertEqual(mock_self.connection.get("dataset"), expected_credentials["dataset"])
295295

296-
def test_set_connection_bigquery_not_oauth(self):
296+
def test_set_connection_bigquery_svc_account(self):
297297
expected_driver = "bigquery"
298298
expected_credentials = {
299-
"method": "not_oauth",
299+
"method": "service-account",
300+
"project": "a_project",
301+
"dataset": "a_dataset",
302+
"keyfile": "/some/path",
303+
}
304+
mock_self = Mock()
305+
mock_self.get_connection_creds.return_value = (expected_credentials, expected_driver)
306+
307+
DbtParser.set_connection(mock_self)
308+
309+
self.assertIsInstance(mock_self.connection, dict)
310+
self.assertEqual(mock_self.connection.get("driver"), expected_driver)
311+
self.assertEqual(mock_self.connection.get("project"), expected_credentials["project"])
312+
self.assertEqual(mock_self.connection.get("dataset"), expected_credentials["dataset"])
313+
self.assertEqual(mock_self.connection.get("keyfile"), expected_credentials["keyfile"])
314+
315+
def test_set_connection_bigquery_not_supported(self):
316+
expected_driver = "bigquery"
317+
expected_credentials = {
318+
"method": "not_supported",
300319
"project": "a_project",
301320
"dataset": "a_dataset",
302321
}
303322

304323
mock_self = Mock()
305324
mock_self.get_connection_creds.return_value = (expected_credentials, expected_driver)
306-
with self.assertRaises(DataDiffDbtBigQueryOauthOnlyError):
325+
with self.assertRaises(DataDiffDbtBigQueryUnsupportedMethodError):
307326
DbtParser.set_connection(mock_self)
308327

309328
self.assertNotIsInstance(mock_self.connection, dict)

0 commit comments

Comments
 (0)