class RequesterPaysNotSupported(Exception):
    """Raised when requester pays data is requested on the Azure execution platform."""


def is_requester_pays(drs_urls: Iterable[str]) -> bool:
    """Identify if any of the given DRS URIs require Google requester pays access.

    :param drs_urls: The DRS URIs to inspect. A single URI passed as a bare
        string is treated as a one-element collection rather than being
        iterated character by character.
    :return: True if at least one URI requires requester pays access, otherwise False.
    :raises RequesterPaysNotSupported: If a URI requires requester pays access
        while the execution platform is Azure.
    """
    # Currently (1/2023), Gen3-hosted AnVIL data in GCS is the only
    # DRS data requiring requester pays access.
    # Even this will end when the AnVIL data is hosted in TDR.
    # Note: If the Gen3 AnVIL DRS URI format is retained by TDR, this function must be updated.
    anvil_drs_uri_prefix = "drs://dg.ANV0"

    # A str is itself an iterable of one-character strings, none of which can
    # match the prefix, so a bare URI would silently be reported as False.
    if isinstance(drs_urls, str):
        drs_urls = [drs_urls]

    for drs_url in drs_urls:
        if not drs_url.strip().startswith(anvil_drs_uri_prefix):
            continue
        if get_execution_context().execution_platform == ExecutionPlatform.AZURE:
            raise RequesterPaysNotSupported(
                f"Requester pays data access is not supported on the Azure platform. \nCannot access: {drs_url}"
            )
        return True
    return False
+ enable_requester_pays(workspace_name, workspace_namespace) + info = get_drs_info(drs_url, access_url=True) if info.access_url: @@ -232,7 +261,9 @@ def head(drs_url: str, workspace_namespace: Optional[str]=WORKSPACE_NAMESPACE, billing_project: Optional[str]=WORKSPACE_GOOGLE_PROJECT): """Head a DRS object by byte.""" - enable_requester_pays(workspace_name, workspace_namespace) + + if is_requester_pays(drs_url): + enable_requester_pays(workspace_name, workspace_namespace) try: blob = get_drs_blob(drs_url, billing_project) with blob.open(chunk_size=num_bytes) as fh: @@ -295,7 +326,8 @@ def copy(drs_uri: str, """Copy a DRS object to either the local filesystem, or to a Google Storage location if `dst` starts with "gs://". """ - enable_requester_pays(workspace_name, workspace_namespace) + if is_requester_pays([drs_uri]): + enable_requester_pays(workspace_name, workspace_namespace) with DRSCopyClient(raise_on_error=True, indicator_type=indicator_type) as cc: cc.workspace = workspace_name cc.workspace_namespace = workspace_namespace @@ -340,7 +372,8 @@ def copy_batch_urls(drs_urls: Iterable[str], indicator_type: Indicator = Indicator.notebook_bar if is_notebook() else Indicator.log, workspace_name: Optional[str] = WORKSPACE_NAME, workspace_namespace: Optional[str] = WORKSPACE_NAMESPACE): - enable_requester_pays(workspace_name, workspace_namespace) + if is_requester_pays(drs_urls): + enable_requester_pays(workspace_name, workspace_namespace) with DRSCopyClient(indicator_type=indicator_type) as cc: cc.workspace = workspace_name cc.workspace_namespace = workspace_namespace @@ -365,7 +398,11 @@ def copy_batch_manifest(manifest: List[Dict[str, str]], workspace_namespace: Optional[str]=WORKSPACE_NAMESPACE): from jsonschema import validate validate(instance=manifest, schema=manifest_schema) - enable_requester_pays(workspace_name, workspace_namespace) + + drs_uri_list = [item['drs_uri'] for item in manifest] + if is_requester_pays(drs_uri_list): + 
enable_requester_pays(workspace_name, workspace_namespace) + with DRSCopyClient(indicator_type=indicator_type) as cc: cc.workspace = workspace_name cc.workspace_namespace = workspace_namespace @@ -381,7 +418,9 @@ def extract_tar_gz(drs_url: str, Default extraction is to the bucket for 'workspace'. """ dst = dst or f"gs://{workspace.get_workspace_bucket(workspace_name)}" - enable_requester_pays(workspace_name, workspace_namespace) + + if is_requester_pays(drs_url): + enable_requester_pays(workspace_name, workspace_namespace) blob = get_drs_blob(drs_url, billing_project) with blob.open() as fh: tar_gz.extract(fh, dst)