Skip to content

Commit 2e20092

Browse files
authored
feat: add retry timeout (#26)
* feat: add retry timeout
* feat: upgrade bigquery client version + add timeout as a predicate to retry
* feat: rename retry config env
* feat: include SSLError as an error worth retrying
* feat: add log level config
* fix: assign retry object
1 parent fe52f52 commit 2e20092

File tree

5 files changed

+57
-11
lines changed

5 files changed

+57
-11
lines changed

task/bq2bq/executor/bumblebee/bigquery_service.py

+20-4
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
from abc import ABC, abstractmethod
55

66
import google as google
7+
import requests.exceptions
78
from google.api_core.exceptions import BadRequest, Forbidden
9+
from google.api_core.retry import if_exception_type, if_transient_error
810
from google.cloud import bigquery
911
from google.cloud.bigquery.job import QueryJobConfig, CreateDisposition
1012
from google.cloud.bigquery.schema import _parse_schema_resource
@@ -50,17 +52,28 @@ def delete_table(self, full_table_name):
5052
def get_table(self, full_table_name):
5153
pass
5254

55+
def if_exception_funcs(fn_origin, fn_additional):
    """Build a retry predicate satisfied when either input predicate is.

    Used to extend google-api-core's transient-error predicate with
    additional exception types without losing the original behaviour.
    """
    def combined_predicate(exception):
        return fn_origin(exception) or fn_additional(exception)

    return combined_predicate
5359

5460
class BigqueryService(BaseBigqueryService):
5561

56-
def __init__(self, client, labels, writer, on_job_finish = None, on_job_register = None):
62+
def __init__(self, client, labels, writer, retry_timeout = None, on_job_finish = None, on_job_register = None):
    """Wrap a bigquery.Client with labels, a writer and a configured retry policy.

    :param retry_timeout: overall retry deadline in seconds; None keeps the
        client library's default deadline.
    """
    self.client = client
    self.labels = labels
    self.writer = writer
    # Retry the usual transient errors plus request timeouts and SSL failures.
    extra_transient = if_exception_type(
        requests.exceptions.Timeout,
        requests.exceptions.SSLError,
    )
    should_retry = if_exception_funcs(if_transient_error, extra_transient)
    base_retry = bigquery.DEFAULT_RETRY
    if retry_timeout:
        base_retry = base_retry.with_deadline(retry_timeout)
    self.retry = base_retry.with_predicate(should_retry)
    self.on_job_finish = on_job_finish
    self.on_job_register = on_job_register
6679

@@ -74,7 +87,8 @@ def execute_query(self, query):
7487

7588
logger.info("executing query")
7689
query_job = self.client.query(query=query,
77-
job_config=query_job_config)
90+
job_config=query_job_config,
91+
retry=self.retry)
7892
logger.info("Job {} is initially in state {} of {} project".format(query_job.job_id, query_job.state,
7993
query_job.project))
8094

@@ -125,7 +139,9 @@ def transform_load(self,
125139
query_job_config.destination = table_ref
126140

127141
logger.info("transform load")
128-
query_job = self.client.query(query=query, job_config=query_job_config)
142+
query_job = self.client.query(query=query,
143+
job_config=query_job_config,
144+
retry=self.retry)
129145
logger.info("Job {} is initially in state {} of {} project".format(query_job.job_id, query_job.state,
130146
query_job.project))
131147

@@ -183,7 +199,7 @@ def create_bigquery_service(task_config: TaskConfigFromEnv, labels, writer, on_j
183199
default_query_job_config.priority = task_config.query_priority
184200
default_query_job_config.allow_field_addition = task_config.allow_field_addition
185201
client = bigquery.Client(project=task_config.execution_project, credentials=credentials, default_query_job_config=default_query_job_config)
186-
return BigqueryService(client, labels, writer, on_job_finish=on_job_finish, on_job_register=on_job_register)
202+
return BigqueryService(client, labels, writer, retry_timeout=task_config.retry_timeout, on_job_finish=on_job_finish, on_job_register=on_job_register)
187203

188204

189205
def _get_bigquery_credentials():

task/bq2bq/executor/bumblebee/config.py

+14
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ def __init__(self):
126126
self._use_spillover = _bool_from_str(get_env_config("USE_SPILLOVER", default="true"))
127127
self._concurrency = _validate_greater_than_zero(int(get_env_config("CONCURRENCY", default=1)))
128128
self._allow_field_addition = _bool_from_str(get_env_config("ALLOW_FIELD_ADDITION", default="false"))
129+
self._retry_timeout = get_env_config("RETRY_TIMEOUT_IN_SECONDS", default=None)
129130

130131
@property
131132
def destination_project(self) -> str:
@@ -178,6 +179,12 @@ def timezone(self):
178179
def concurrency(self) -> int:
179180
return self._concurrency
180181

182+
@property
def retry_timeout(self) -> Optional[float]:
    """Retry deadline in seconds, or None when the configured value is unset/empty."""
    return float(self._retry_timeout) if self._retry_timeout else None
187+
181188
def print(self):
182189
logger.info("task config:\n{}".format(
183190
"\n".join([
@@ -348,6 +355,7 @@ def __init__(self, raw_properties):
348355
self._use_spillover = _bool_from_str(self._get_property_or_default("USE_SPILLOVER", "true"))
349356
self._concurrency = _validate_greater_than_zero(int(self._get_property_or_default("CONCURRENCY", 1)))
350357
self._allow_field_addition = _bool_from_str(self._get_property_or_default("ALLOW_FIELD_ADDITION", "false"))
358+
self._retry_timeout = self._get_property_or_default("RETRY_TIMEOUT_IN_SECONDS", None)
351359

352360
@property
353361
def sql_type(self) -> str:
@@ -412,6 +420,12 @@ def filter_expression(self) -> str:
412420
def allow_field_addition(self) -> bool:
413421
return self._allow_field_addition
414422

423+
@property
def retry_timeout(self) -> Optional[float]:
    """Retry deadline in seconds parsed from the raw property, or None if absent."""
    raw = self._retry_timeout
    if not raw:
        return None
    return float(raw)
428+
415429
def print(self):
416430
logger.info("task config:\n{}".format(
417431
"\n".join([

task/bq2bq/executor/bumblebee/log.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
import sys
22
import logging
3+
import os
34

5+
def get_log_level():
    """Resolve the LOG_LEVEL environment variable to a logging level number.

    Unknown or missing values fall back to logging.INFO. Uses the public
    logging.getLevelName lookup instead of the private logging._nameToLevel
    mapping the original relied on.
    """
    name = os.environ.get("LOG_LEVEL", "INFO").upper()
    level = logging.getLevelName(name)
    # For unknown names getLevelName returns the string "Level <name>",
    # so a non-int result means the configured level is invalid.
    return level if isinstance(level, int) else logging.INFO
49

510
def get_logger(name: str):
    """Return a named logger, configuring root logging to stdout at the LOG_LEVEL level."""
    log = logging.getLogger(name)
    logging.basicConfig(
        level=get_log_level(),
        stream=sys.stdout,
        format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    return log

task/bq2bq/executor/requirements.txt

+7-6
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@ cachetools==4.1.1
22
certifi==2020.6.20
33
chardet==3.0.4
44
google==3.0.0
5-
google-api-core==1.21.0
6-
google-auth==1.18.0
7-
google-cloud-bigquery==1.25.0
8-
google-cloud-core==1.3.0
9-
google-resumable-media==0.5.1
10-
googleapis-common-protos==1.52.0
5+
google-api-core==2.8.0
6+
google-auth==2.29.0
7+
google-cloud-bigquery==1.28.3
8+
google-cloud-core==2.4.1
9+
google-crc32c==1.5.0
10+
google-resumable-media==1.3.3
11+
googleapis-common-protos==1.56.0
1112
idna==2.10
1213
iso8601==0.1.12
1314
protobuf==3.12.2

task/bq2bq/executor/tests/test_config.py

+10
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,16 @@ def test_concurrency(self):
156156

157157
self.assertEqual(config.concurrency, 2)
158158

159+
def test_retry_timeout(self):
    """RETRY_TIMEOUT_IN_SECONDS is optional and parsed as a float when provided."""
    self.set_vars_with_default()
    self.assertEqual(TaskConfigFromEnv().retry_timeout, None)

    self.set_vars_with_default()
    os.environ['RETRY_TIMEOUT_IN_SECONDS'] = "120.0"
    self.assertEqual(TaskConfigFromEnv().retry_timeout, 120.0)
168+
159169
def test_concurrency_should_not_zero_exception(self):
160170
self.set_vars_with_default()
161171
os.environ['CONCURRENCY'] = "0"

0 commit comments

Comments
 (0)