Skip to content

Commit

Permalink
Only enable requester pays when necessary #400
Browse files Browse the repository at this point in the history
Only enable (Google) requester pays data access when the
given DRS URIs require it and the platform on which TNU is
running supports it.
  • Loading branch information
Michael Baumann committed Jan 18, 2023
1 parent 1fb315a commit f2a4af5
Showing 1 changed file with 55 additions and 12 deletions.
67 changes: 55 additions & 12 deletions terra_notebook_utils/drs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

from terra_notebook_utils import WORKSPACE_BUCKET, WORKSPACE_NAME, DRS_RESOLVER_URL, WORKSPACE_NAMESPACE, \
WORKSPACE_GOOGLE_PROJECT
from terra_notebook_utils import workspace, gs, tar_gz, TERRA_DEPLOYMENT_ENV, _GS_SCHEMA
from terra_notebook_utils.utils import is_notebook
from terra_notebook_utils import workspace, gs, tar_gz, TERRA_DEPLOYMENT_ENV, _GS_SCHEMA, ExecutionPlatform
from terra_notebook_utils.utils import is_notebook, get_execution_context
from terra_notebook_utils.http_utils import http
from terra_notebook_utils.blobstore.gs import GSBlob
from terra_notebook_utils.blobstore.local import LocalBlob
Expand All @@ -26,16 +26,46 @@
class DRSResolutionError(Exception):
    """Error type for DRS resolution failures (raise sites are outside this excerpt)."""

class RequesterPaysNotSupported(Exception):
    """Requester pays data access was required on a platform that does not support it."""

def _parse_gs_url(gs_url: str) -> Tuple[str, str]:
    """
    Split a Google Storage URL into its bucket name and object key.

    :param gs_url: a URL expected to start with the `_GS_SCHEMA` prefix
    :returns: a `(bucket_name, object_key)` tuple, split at the first "/"
        following the schema prefix
    :raises: RuntimeError if `gs_url` does not start with `_GS_SCHEMA`
    """
    # Guard clause: reject anything that is not a gs URL before slicing.
    if not gs_url.startswith(_GS_SCHEMA):
        raise RuntimeError(f'Invalid gs url schema. {gs_url} does not start with {_GS_SCHEMA}')

    path_without_schema = gs_url[len(_GS_SCHEMA):]
    # Only the first "/" separates bucket from key; the key may itself contain "/".
    bucket_name, object_key = path_without_schema.split("/", 1)
    return bucket_name, object_key


def is_requester_pays(drs_urls: Iterable[str]) -> bool:
    """
    Identify if any of the given DRS URIs require Google requester pays access.

    :param drs_urls: the DRS URIs to inspect. A single URI passed as a bare
        string is treated as a one-element collection instead of being
        iterated character by character.
    :returns: True if at least one URI requires requester pays access, False
        otherwise (including for an empty collection)
    :raises: RequesterPaysNotSupported if a URI requires requester pays access
        and the execution platform is Azure
    """
    # Currently (1/2023), Gen3-hosted AnVIL data in GCS is the only
    # DRS data requiring requester pays access.
    # Even this will end when the AnVIL data is hosted in TDR.
    # Note: If the Gen3 AnVIL DRS URI format is retained by TDR, this function must be updated.
    ANVIL_DRS_URI_PREFIX = "drs://dg.ANV0"

    # A str is itself an iterable of one-character strings, none of which can
    # match the prefix, so a bare URI would silently yield False. Some callers
    # in this module pass a single URI directly, so normalize it here.
    if isinstance(drs_urls, str):
        drs_urls = [drs_urls]

    for drs_url in drs_urls:
        if not drs_url.strip().startswith(ANVIL_DRS_URI_PREFIX):
            continue
        # The platform is only queried once a requester pays URI is found, so
        # inputs without such URIs never depend on the execution context.
        if get_execution_context().execution_platform == ExecutionPlatform.AZURE:
            raise RequesterPaysNotSupported(
                f"Requester pays data access is not supported on the Azure platform. Cannot access: {drs_url}"
            )
        return True
    return False


@lru_cache()
def enable_requester_pays(workspace_name: Optional[str]=WORKSPACE_NAME,
workspace_namespace: Optional[str]=WORKSPACE_NAMESPACE):

assert get_execution_context().execution_platform != ExecutionPlatform.AZURE, \
"Requester pays data access is not supported on the Terra Azure platform."

if not workspace_name:
raise RuntimeError('Workspace name is not set. Please set the environment variable '
'WORKSPACE_NAME with the name of a valid Terra Workspace.')
Expand Down Expand Up @@ -92,11 +122,14 @@ def access(drs_url: str,
workspace_namespace: Optional[str]=WORKSPACE_NAMESPACE,
billing_project: Optional[str]=WORKSPACE_GOOGLE_PROJECT) -> str:
"""Return a signed url for a drs:// URI, if available."""
# We enable requester pays by specifying the workspace/namespace combo, not
# with the billing project. Rawls then enables requester pays for the attached
# project, but this won't work if a user specifies a project unattached to
# the Terra workspace.
enable_requester_pays(workspace_name, workspace_namespace)

if is_requester_pays([drs_url]):
# We enable requester pays by specifying the workspace/namespace combo, not
# with the billing project. Rawls then enables requester pays for the attached
# project, but this won't work if a user specifies a project unattached to
# the Terra workspace.
enable_requester_pays(workspace_name, workspace_namespace)

info = get_drs_info(drs_url, access_url=True)

if info.access_url:
Expand Down Expand Up @@ -232,7 +265,9 @@ def head(drs_url: str,
workspace_namespace: Optional[str]=WORKSPACE_NAMESPACE,
billing_project: Optional[str]=WORKSPACE_GOOGLE_PROJECT):
"""Head a DRS object by byte."""
enable_requester_pays(workspace_name, workspace_namespace)

if is_requester_pays(drs_url):
enable_requester_pays(workspace_name, workspace_namespace)
try:
blob = get_drs_blob(drs_url, billing_project)
with blob.open(chunk_size=num_bytes) as fh:
Expand Down Expand Up @@ -295,7 +330,8 @@ def copy(drs_uri: str,
"""Copy a DRS object to either the local filesystem, or to a Google Storage location if `dst` starts with
"gs://".
"""
enable_requester_pays(workspace_name, workspace_namespace)
if is_requester_pays([drs_uri]):
enable_requester_pays(workspace_name, workspace_namespace)
with DRSCopyClient(raise_on_error=True, indicator_type=indicator_type) as cc:
cc.workspace = workspace_name
cc.workspace_namespace = workspace_namespace
Expand Down Expand Up @@ -340,7 +376,8 @@ def copy_batch_urls(drs_urls: Iterable[str],
indicator_type: Indicator = Indicator.notebook_bar if is_notebook() else Indicator.log,
workspace_name: Optional[str] = WORKSPACE_NAME,
workspace_namespace: Optional[str] = WORKSPACE_NAMESPACE):
enable_requester_pays(workspace_name, workspace_namespace)
if is_requester_pays(drs_urls):
enable_requester_pays(workspace_name, workspace_namespace)
with DRSCopyClient(indicator_type=indicator_type) as cc:
cc.workspace = workspace_name
cc.workspace_namespace = workspace_namespace
Expand All @@ -365,7 +402,11 @@ def copy_batch_manifest(manifest: List[Dict[str, str]],
workspace_namespace: Optional[str]=WORKSPACE_NAMESPACE):
from jsonschema import validate
validate(instance=manifest, schema=manifest_schema)
enable_requester_pays(workspace_name, workspace_namespace)

drs_uri_list = [item['drs_uri'] for item in manifest]
if is_requester_pays(drs_uri_list):
enable_requester_pays(workspace_name, workspace_namespace)

with DRSCopyClient(indicator_type=indicator_type) as cc:
cc.workspace = workspace_name
cc.workspace_namespace = workspace_namespace
Expand All @@ -381,7 +422,9 @@ def extract_tar_gz(drs_url: str,
Default extraction is to the bucket for 'workspace'.
"""
dst = dst or f"gs://{workspace.get_workspace_bucket(workspace_name)}"
enable_requester_pays(workspace_name, workspace_namespace)

if is_requester_pays(drs_url):
enable_requester_pays(workspace_name, workspace_namespace)
blob = get_drs_blob(drs_url, billing_project)
with blob.open() as fh:
tar_gz.extract(fh, dst)
Expand Down

0 comments on commit f2a4af5

Please sign in to comment.