Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 785d076

Browse files
authored
Merge pull request #609 from datafold/DX-810
--dbt add support for BQ service-account
2 parents 0d624fa + 1305871 commit 785d076

File tree

4 files changed

+55
-15
lines changed

4 files changed

+55
-15
lines changed

data_diff/dbt_parser.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from dbt.config.renderer import ProfileRenderer
1212

1313
from data_diff.errors import (
14-
DataDiffDbtBigQueryOauthOnlyError,
14+
DataDiffDbtBigQueryUnsupportedMethodError,
1515
DataDiffDbtConnectionNotImplementedError,
1616
DataDiffDbtCoreNoRunnerError,
1717
DataDiffDbtNoSuccessfulModelsInRunError,
@@ -319,17 +319,25 @@ def set_connection(self):
319319
else:
320320
raise DataDiffDbtSnowflakeSetConnectionError("Snowflake: unsupported auth method")
321321
elif conn_type == "bigquery":
322+
supported_methods = ["oauth", "service-account"]
322323
method = credentials.get("method")
323324
# there are many connection types https://docs.getdbt.com/reference/warehouse-setups/bigquery-setup#oauth-via-gcloud
324325
# this assumes that the user is auth'd via `gcloud auth application-default login`
325-
if method is None or method != "oauth":
326-
raise DataDiffDbtBigQueryOauthOnlyError("Oauth is the current method supported for Big Query.")
326+
if method not in supported_methods:
327+
raise DataDiffDbtBigQueryUnsupportedMethodError(
328+
f"Method: {method} is not in the current methods supported for Big Query ({supported_methods})."
329+
)
330+
327331
conn_info = {
328332
"driver": conn_type,
329333
"project": credentials.get("project"),
330334
"dataset": credentials.get("dataset"),
331335
}
336+
332337
self.threads = credentials.get("threads")
338+
if method == supported_methods[1]:
339+
conn_info["keyfile"] = credentials.get("keyfile")
340+
333341
elif conn_type == "duckdb":
334342
conn_info = {
335343
"driver": conn_type,

data_diff/errors.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ class DataDiffDbtSnowflakeSetConnectionError(Exception):
2626
"Raised when a dbt snowflake profile has unexpected values."
2727

2828

29-
class DataDiffDbtBigQueryOauthOnlyError(Exception):
30-
"Raised when trying to use a method other than oauth with BigQuery."
29+
class DataDiffDbtBigQueryUnsupportedMethodError(Exception):
30+
"Raised when trying to use an unsupported connection with BigQuery."
3131

3232

3333
class DataDiffDbtRedshiftPasswordOnlyError(Exception):

data_diff/sqeleton/databases/bigquery.py

+18-5
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,18 @@ def import_bigquery():
3636
return bigquery
3737

3838

39+
def import_bigquery_service_account():
40+
from google.oauth2 import service_account
41+
42+
return service_account
43+
44+
3945
class Mixin_MD5(AbstractMixin_MD5):
4046
def md5_as_int(self, s: str) -> str:
4147
return f"cast(cast( ('0x' || substr(TO_HEX(md5({s})), 18)) as int64) as numeric)"
4248

4349

4450
class Mixin_NormalizeValue(AbstractMixin_NormalizeValue):
45-
4651
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
4752
if coltype.rounds:
4853
timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))"
@@ -144,8 +149,8 @@ class Dialect(BaseDialect, Mixin_Schema):
144149
"BOOL": Boolean,
145150
"JSON": JSON,
146151
}
147-
TYPE_ARRAY_RE = re.compile(r'ARRAY<(.+)>')
148-
TYPE_STRUCT_RE = re.compile(r'STRUCT<(.+)>')
152+
TYPE_ARRAY_RE = re.compile(r"ARRAY<(.+)>")
153+
TYPE_STRUCT_RE = re.compile(r"STRUCT<(.+)>")
149154
MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_TimeTravel, Mixin_RandomSample}
150155

151156
def random(self) -> str:
@@ -173,7 +178,6 @@ def parse_type(
173178
) -> ColType:
174179
col_type = super().parse_type(table_path, col_name, type_repr, *args, **kwargs)
175180
if isinstance(col_type, UnknownColType):
176-
177181
m = self.TYPE_ARRAY_RE.fullmatch(type_repr)
178182
if m:
179183
item_type = self.parse_type(table_path, col_name, m.group(1), *args, **kwargs)
@@ -207,9 +211,18 @@ class BigQuery(Database):
207211
dialect = Dialect()
208212

209213
def __init__(self, project, *, dataset, **kw):
214+
credentials = None
210215
bigquery = import_bigquery()
211216

212-
self._client = bigquery.Client(project, **kw)
217+
keyfile = kw.pop("keyfile", None)
218+
if keyfile:
219+
bigquery_service_account = import_bigquery_service_account()
220+
credentials = bigquery_service_account.Credentials.from_service_account_file(
221+
keyfile,
222+
scopes=["https://www.googleapis.com/auth/cloud-platform"],
223+
)
224+
225+
self._client = bigquery.Client(project=project, credentials=credentials, **kw)
213226
self.project = project
214227
self.dataset = dataset
215228

tests/test_dbt.py

+24-5
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
from data_diff.cloud.datafold_api import TCloudApiOrgMeta
66
from data_diff.diff_tables import Algorithm
77
from data_diff.errors import (
8+
DataDiffDbtBigQueryUnsupportedMethodError,
89
DataDiffCustomSchemaNoConfigError,
9-
DataDiffDbtBigQueryOauthOnlyError,
1010
DataDiffDbtConnectionNotImplementedError,
1111
DataDiffDbtCoreNoRunnerError,
1212
DataDiffDbtNoSuccessfulModelsInRunError,
@@ -271,7 +271,7 @@ def test_set_connection_snowflake_key_and_password(self):
271271

272272
self.assertNotIsInstance(mock_self.connection, dict)
273273

274-
def test_set_connection_bigquery_success(self):
274+
def test_set_connection_bigquery_oauth(self):
275275
expected_driver = "bigquery"
276276
expected_credentials = {
277277
"method": "oauth",
@@ -288,17 +288,36 @@ def test_set_connection_bigquery_success(self):
288288
self.assertEqual(mock_self.connection.get("project"), expected_credentials["project"])
289289
self.assertEqual(mock_self.connection.get("dataset"), expected_credentials["dataset"])
290290

291-
def test_set_connection_bigquery_not_oauth(self):
291+
def test_set_connection_bigquery_svc_account(self):
292292
expected_driver = "bigquery"
293293
expected_credentials = {
294-
"method": "not_oauth",
294+
"method": "service-account",
295+
"project": "a_project",
296+
"dataset": "a_dataset",
297+
"keyfile": "/some/path",
298+
}
299+
mock_self = Mock()
300+
mock_self.get_connection_creds.return_value = (expected_credentials, expected_driver)
301+
302+
DbtParser.set_connection(mock_self)
303+
304+
self.assertIsInstance(mock_self.connection, dict)
305+
self.assertEqual(mock_self.connection.get("driver"), expected_driver)
306+
self.assertEqual(mock_self.connection.get("project"), expected_credentials["project"])
307+
self.assertEqual(mock_self.connection.get("dataset"), expected_credentials["dataset"])
308+
self.assertEqual(mock_self.connection.get("keyfile"), expected_credentials["keyfile"])
309+
310+
def test_set_connection_bigquery_not_supported(self):
311+
expected_driver = "bigquery"
312+
expected_credentials = {
313+
"method": "not_supported",
295314
"project": "a_project",
296315
"dataset": "a_dataset",
297316
}
298317

299318
mock_self = Mock()
300319
mock_self.get_connection_creds.return_value = (expected_credentials, expected_driver)
301-
with self.assertRaises(DataDiffDbtBigQueryOauthOnlyError):
320+
with self.assertRaises(DataDiffDbtBigQueryUnsupportedMethodError):
302321
DbtParser.set_connection(mock_self)
303322

304323
self.assertNotIsInstance(mock_self.connection, dict)

0 commit comments

Comments
 (0)