Skip to content

Commit

Permalink
fix(ingest/tableau): retry on auth error for special case (datahub-pr…
Browse files Browse the repository at this point in the history
  • Loading branch information
mayurinehate authored and chakru-r committed Jan 11, 2025
1 parent 91261a1 commit dc3804f
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 20 deletions.
68 changes: 50 additions & 18 deletions metadata-ingestion/src/datahub/ingestion/source/tableau/tableau.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import logging
import re
import time
from collections import OrderedDict
from dataclasses import dataclass
from datetime import datetime
from collections import OrderedDict, defaultdict
from dataclasses import dataclass, field as dataclass_field
from datetime import datetime, timedelta, timezone
from functools import lru_cache
from typing import (
Any,
Expand Down Expand Up @@ -196,6 +196,11 @@
504, # Gateway Timeout
]

# From experience, this expiry time typically ranges from 50 minutes
# to 2 hours but might as well be configurable. We will allow up to
# 10 minutes of such expiry time.
REGULAR_AUTH_EXPIRY_PERIOD = timedelta(minutes=10)

logger: logging.Logger = logging.getLogger(__name__)

# Replace / with |
Expand Down Expand Up @@ -637,6 +642,7 @@ class SiteIdContentUrl:
site_content_url: str


@dataclass
class TableauSourceReport(StaleEntityRemovalSourceReport):
get_all_datasources_query_failed: bool = False
num_get_datasource_query_failures: int = 0
Expand All @@ -653,7 +659,14 @@ class TableauSourceReport(StaleEntityRemovalSourceReport):
num_upstream_table_lineage_failed_parse_sql: int = 0
num_upstream_fine_grained_lineage_failed_parse_sql: int = 0
num_hidden_assets_skipped: int = 0
logged_in_user: List[UserInfo] = []
logged_in_user: List[UserInfo] = dataclass_field(default_factory=list)
last_authenticated_at: Optional[datetime] = None

num_expected_tableau_metadata_queries: int = 0
num_actual_tableau_metadata_queries: int = 0
tableau_server_error_stats: Dict[str, int] = dataclass_field(
default_factory=(lambda: defaultdict(int))
)


def report_user_role(report: TableauSourceReport, server: Server) -> None:
Expand Down Expand Up @@ -724,6 +737,7 @@ def _authenticate(self, site_content_url: str) -> None:
try:
logger.info(f"Authenticated to Tableau site: '{site_content_url}'")
self.server = self.config.make_tableau_client(site_content_url)
self.report.last_authenticated_at = datetime.now(timezone.utc)
report_user_role(report=self.report, server=self.server)
# Note that we're not catching ConfigurationError, since we want that to throw.
except ValueError as e:
Expand Down Expand Up @@ -807,10 +821,13 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
site_source = TableauSiteSource(
config=self.config,
ctx=self.ctx,
site=site
if site
else SiteIdContentUrl(
site_id=self.server.site_id, site_content_url=self.config.site
site=(
site
if site
else SiteIdContentUrl(
site_id=self.server.site_id,
site_content_url=self.config.site,
)
),
report=self.report,
server=self.server,
Expand Down Expand Up @@ -925,6 +942,7 @@ def _re_authenticate(self) -> None:
# Sign-in again may not be enough because Tableau sometimes caches invalid sessions
# so we need to recreate the Tableau Server object
self.server = self.config.make_tableau_client(self.site_content_url)
self.report.last_authenticated_at = datetime.now(timezone.utc)

def _populate_usage_stat_registry(self) -> None:
if self.server is None:
Expand Down Expand Up @@ -1190,6 +1208,7 @@ def get_connection_object_page(
)
try:
assert self.server is not None
self.report.num_actual_tableau_metadata_queries += 1
query_data = query_metadata_cursor_based_pagination(
server=self.server,
main_query=query,
Expand All @@ -1199,25 +1218,36 @@ def get_connection_object_page(
qry_filter=query_filter,
)

except REAUTHENTICATE_ERRORS:
if not retry_on_auth_error:
except REAUTHENTICATE_ERRORS as e:
self.report.tableau_server_error_stats[e.__class__.__name__] += 1
if not retry_on_auth_error or retries_remaining <= 0:
raise

# If ingestion has been running for over 2 hours, the Tableau
# temporary credentials will expire. If this happens, this exception
# will be thrown, and we need to re-authenticate and retry.
self._re_authenticate()
# We have been getting some irregular authorization errors like below well before the expected expiry time
# - within a few seconds of initial authentication. We'll retry without re-auth for such cases.
# <class 'tableauserverclient.server.endpoint.exceptions.NonXMLResponseError'>:
# b'{"timestamp":"xxx","status":401,"error":"Unauthorized","path":"/relationship-service-war/graphql"}'
if self.report.last_authenticated_at and (
datetime.now(timezone.utc) - self.report.last_authenticated_at
> REGULAR_AUTH_EXPIRY_PERIOD
):
# If ingestion has been running for over 2 hours, the Tableau
# temporary credentials will expire. If this happens, this exception
# will be thrown, and we need to re-authenticate and retry.
self._re_authenticate()

return self.get_connection_object_page(
query=query,
connection_type=connection_type,
query_filter=query_filter,
fetch_size=fetch_size,
current_cursor=current_cursor,
retry_on_auth_error=False,
retry_on_auth_error=True,
retries_remaining=retries_remaining - 1,
)

except InternalServerError as ise:
self.report.tableau_server_error_stats[InternalServerError.__name__] += 1
# In some cases Tableau Server returns a 504 error, which is a timeout error, so it is worth retrying.
# Extended with other retryable errors.
if ise.code in RETRIABLE_ERROR_CODES:
Expand All @@ -1230,13 +1260,14 @@ def get_connection_object_page(
query_filter=query_filter,
fetch_size=fetch_size,
current_cursor=current_cursor,
retry_on_auth_error=False,
retry_on_auth_error=True,
retries_remaining=retries_remaining - 1,
)
else:
raise ise

except OSError:
self.report.tableau_server_error_stats[OSError.__name__] += 1
# In tableauserverclient 0.26 (which was yanked and released in 0.28 on 2023-10-04),
# the request logic was changed to use threads.
# https://github.com/tableau/server-client-python/commit/307d8a20a30f32c1ce615cca7c6a78b9b9bff081
Expand All @@ -1251,7 +1282,7 @@ def get_connection_object_page(
query_filter=query_filter,
fetch_size=fetch_size,
current_cursor=current_cursor,
retry_on_auth_error=False,
retry_on_auth_error=True,
retries_remaining=retries_remaining - 1,
)

Expand Down Expand Up @@ -1339,7 +1370,7 @@ def get_connection_object_page(
query_filter=query_filter,
fetch_size=fetch_size,
current_cursor=current_cursor,
retry_on_auth_error=False,
retry_on_auth_error=True,
retries_remaining=retries_remaining,
)
raise RuntimeError(f"Query {connection_type} error: {errors}")
Expand Down Expand Up @@ -1377,6 +1408,7 @@ def get_connection_objects(
while has_next_page:
filter_: str = make_filter(filter_page)

self.report.num_expected_tableau_metadata_queries += 1
(
connection_objects,
current_cursor,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
import pathlib
from typing import Any, Dict, List, cast
from typing import Any, Dict, List, Union, cast
from unittest import mock

import pytest
Expand All @@ -13,10 +13,15 @@
GroupItem,
ProjectItem,
SiteItem,
UserItem,
ViewItem,
WorkbookItem,
)
from tableauserverclient.models.reference_item import ResourceReference
from tableauserverclient.server.endpoint.exceptions import (
NonXMLResponseError,
TableauError,
)

from datahub.emitter.mce_builder import DEFAULT_ENV, make_schema_field_urn
from datahub.emitter.mcp import MetadataChangeProposalWrapper
Expand Down Expand Up @@ -270,7 +275,7 @@ def side_effect_site_get_by_id(id, *arg, **kwargs):


def mock_sdk_client(
side_effect_query_metadata_response: List[dict],
side_effect_query_metadata_response: List[Union[dict, TableauError]],
datasources_side_effect: List[dict],
sign_out_side_effect: List[dict],
) -> mock.MagicMock:
Expand Down Expand Up @@ -1312,6 +1317,61 @@ def test_permission_warning(pytestconfig, tmp_path, mock_datahub_graph):
)


@freeze_time(FROZEN_TIME)
@pytest.mark.integration
def test_retry_on_error(pytestconfig, tmp_path, mock_datahub_graph):
    """Check that an auth-style error on a metadata query is retried.

    The mocked SDK client is configured so that its first metadata-query
    response is a ``NonXMLResponseError`` carrying a 401 payload, followed by
    the regular ``mock_data()`` responses. After one call to
    ``get_connection_object_page`` with a single retry allowed, the report is
    expected to show two attempted queries, one recorded
    ``NonXMLResponseError``, and no warnings or failures.
    """
    checkpoint_graph_target = "datahub.ingestion.source.state_provider.datahub_ingestion_checkpointing_provider.DataHubGraph"
    tableau_server_target = "datahub.ingestion.source.tableau.tableau.Server"

    with mock.patch(checkpoint_graph_target, mock_datahub_graph) as patched_graph:
        patched_graph.return_value = mock_datahub_graph

        with mock.patch(tableau_server_target) as patched_server:
            # First response is the irregular 401; the rest are normal payloads.
            unauthorized_error = NonXMLResponseError(
                """{"timestamp":"xxx","status":401,"error":"Unauthorized","path":"/relationship-service-war/graphql"}"""
            )
            metadata_responses = [unauthorized_error, *mock_data()]

            sdk_client = mock_sdk_client(
                side_effect_query_metadata_response=metadata_responses,
                sign_out_side_effect=[{}],
                datasources_side_effect=[{}],
            )

            # Replace the users endpoint so a single user lookup yields an admin user.
            admin_user = UserItem(
                name="name", site_role=UserItem.Roles.SiteAdministratorExplorer
            )
            sdk_client.users = mock.Mock()
            sdk_client.users.get_by_id.side_effect = [admin_user]
            patched_server.return_value = sdk_client

            report = TableauSourceReport()
            mocked_site = mock.MagicMock(
                spec=SiteItem, id="Site1", content_url="site1"
            )
            site_source = TableauSiteSource(
                platform="tableau",
                config=mock.MagicMock(),
                ctx=mock.MagicMock(),
                site=mocked_site,
                server=patched_server.return_value,
                report=report,
            )

            site_source.get_connection_object_page(
                query=mock.MagicMock(),
                connection_type=mock.MagicMock(),
                query_filter=mock.MagicMock(),
                current_cursor=None,
                retries_remaining=1,
                fetch_size=10,
            )

            # One failed attempt plus one successful retry.
            assert report.num_actual_tableau_metadata_queries == 2
            assert report.tableau_server_error_stats
            assert report.tableau_server_error_stats["NonXMLResponseError"] == 1

            # The retried error must not surface in the report.
            assert report.warnings == []
            assert report.failures == []


@freeze_time(FROZEN_TIME)
@pytest.mark.parametrize(
"extract_project_hierarchy, allowed_projects",
Expand Down

0 comments on commit dc3804f

Please sign in to comment.