Skip to content

Commit 70b55c9

Browse files
Only enable requester pays when necessary #400 (#402)
Only enable (Google) requester pays data access when the given DRS URIs require it and the platform on which TNU (terra-notebook-utils) is running supports it.
1 parent 68b7888 commit 70b55c9

File tree

2 files changed

+154
-18
lines changed

2 files changed

+154
-18
lines changed

terra_notebook_utils/drs.py

Lines changed: 60 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99

1010
from terra_notebook_utils import WORKSPACE_BUCKET, WORKSPACE_NAME, DRS_RESOLVER_URL, WORKSPACE_NAMESPACE, \
1111
WORKSPACE_GOOGLE_PROJECT
12-
from terra_notebook_utils import workspace, gs, tar_gz, TERRA_DEPLOYMENT_ENV, _GS_SCHEMA
13-
from terra_notebook_utils.utils import is_notebook
12+
from terra_notebook_utils import workspace, gs, tar_gz, TERRA_DEPLOYMENT_ENV, _GS_SCHEMA, ExecutionPlatform
13+
from terra_notebook_utils.utils import is_notebook, get_execution_context
1414
from terra_notebook_utils.http import http
1515
from terra_notebook_utils.blobstore.gs import GSBlob
1616
from terra_notebook_utils.blobstore.local import LocalBlob
@@ -26,16 +26,51 @@
2626
class DRSResolutionError(Exception):
2727
pass
2828

29+
class RequesterPaysNotSupported(Exception):
30+
pass
31+
2932
def _parse_gs_url(gs_url: str) -> Tuple[str, str]:
3033
if gs_url.startswith(_GS_SCHEMA):
3134
bucket_name, object_key = gs_url[len(_GS_SCHEMA):].split("/", 1)
3235
return bucket_name, object_key
3336
else:
3437
raise RuntimeError(f'Invalid gs url schema. {gs_url} does not start with {_GS_SCHEMA}')
3538

39+
40+
def is_requester_pays(drs_urls: Iterable[str]) -> bool:
41+
"""
42+
Identify if any of the given DRS URIs require Google requester pays access
43+
44+
:raises: RequesterPaysNotSupported
45+
"""
46+
for drs_url in drs_urls:
47+
# Currently (1/2023), Gen3-hosted AnVIL data in GCS is the only
48+
# DRS data requiring requester pays access.
49+
# Even this will end when the AnVIL data is hosted in TDR.
50+
# Note: If the Gen3 AnVIL DRS URI format is retained by TDR, this function must be updated.
51+
52+
# The DRS specification (v1.2) states compact identifiers must be lowercase alphanumerical values:
53+
# https://ga4gh.github.io/data-repository-service-schemas/preview/release/drs-1.2.0/docs/#section/DRS-URIs
54+
# Yet for historical reasons the DRS URIs minted by Gen3 use uppercase.
55+
# Perform a case-insensitive comparison.
56+
anvil_drs_uri_prefix = "drs://dg.ANV0".lower()
57+
if drs_url.strip().lower().startswith(anvil_drs_uri_prefix):
58+
if get_execution_context().execution_platform == ExecutionPlatform.AZURE:
59+
raise RequesterPaysNotSupported(
60+
f"Requester pays data access is not supported on the Azure platform. Cannot access: {drs_url}"
61+
)
62+
else:
63+
return True
64+
return False
65+
66+
3667
@lru_cache()
3768
def enable_requester_pays(workspace_name: Optional[str]=WORKSPACE_NAME,
3869
workspace_namespace: Optional[str]=WORKSPACE_NAMESPACE):
70+
71+
assert get_execution_context().execution_platform != ExecutionPlatform.AZURE, \
72+
"Requester pays data access is not supported on the Terra Azure platform."
73+
3974
if not workspace_name:
4075
raise RuntimeError('Workspace name is not set. Please set the environment variable '
4176
'WORKSPACE_NAME with the name of a valid Terra Workspace.')
@@ -92,11 +127,14 @@ def access(drs_url: str,
92127
workspace_namespace: Optional[str]=WORKSPACE_NAMESPACE,
93128
billing_project: Optional[str]=WORKSPACE_GOOGLE_PROJECT) -> str:
94129
"""Return a signed url for a drs:// URI, if available."""
95-
# We enable requester pays by specifying the workspace/namespace combo, not
96-
# with the billing project. Rawls then enables requester pays for the attached
97-
# project, but this won't work if a user specifies a project unattached to
98-
# the Terra workspace.
99-
enable_requester_pays(workspace_name, workspace_namespace)
130+
131+
if is_requester_pays([drs_url]):
132+
# We enable requester pays by specifying the workspace/namespace combo, not
133+
# with the billing project. Rawls then enables requester pays for the attached
134+
# project, but this won't work if a user specifies a project unattached to
135+
# the Terra workspace.
136+
enable_requester_pays(workspace_name, workspace_namespace)
137+
100138
info = get_drs_info(drs_url, access_url=True)
101139

102140
if info.access_url:
@@ -232,7 +270,9 @@ def head(drs_url: str,
232270
workspace_namespace: Optional[str]=WORKSPACE_NAMESPACE,
233271
billing_project: Optional[str]=WORKSPACE_GOOGLE_PROJECT):
234272
"""Head a DRS object by byte."""
235-
enable_requester_pays(workspace_name, workspace_namespace)
273+
274+
if is_requester_pays([drs_url]):
275+
enable_requester_pays(workspace_name, workspace_namespace)
236276
try:
237277
blob = get_drs_blob(drs_url, billing_project)
238278
with blob.open(chunk_size=num_bytes) as fh:
@@ -295,7 +335,8 @@ def copy(drs_uri: str,
295335
"""Copy a DRS object to either the local filesystem, or to a Google Storage location if `dst` starts with
296336
"gs://".
297337
"""
298-
enable_requester_pays(workspace_name, workspace_namespace)
338+
if is_requester_pays([drs_uri]):
339+
enable_requester_pays(workspace_name, workspace_namespace)
299340
with DRSCopyClient(raise_on_error=True, indicator_type=indicator_type) as cc:
300341
cc.workspace = workspace_name
301342
cc.workspace_namespace = workspace_namespace
@@ -340,7 +381,8 @@ def copy_batch_urls(drs_urls: Iterable[str],
340381
indicator_type: Indicator = Indicator.notebook_bar if is_notebook() else Indicator.log,
341382
workspace_name: Optional[str] = WORKSPACE_NAME,
342383
workspace_namespace: Optional[str] = WORKSPACE_NAMESPACE):
343-
enable_requester_pays(workspace_name, workspace_namespace)
384+
if is_requester_pays(drs_urls):
385+
enable_requester_pays(workspace_name, workspace_namespace)
344386
with DRSCopyClient(indicator_type=indicator_type) as cc:
345387
cc.workspace = workspace_name
346388
cc.workspace_namespace = workspace_namespace
@@ -365,7 +407,11 @@ def copy_batch_manifest(manifest: List[Dict[str, str]],
365407
workspace_namespace: Optional[str]=WORKSPACE_NAMESPACE):
366408
from jsonschema import validate
367409
validate(instance=manifest, schema=manifest_schema)
368-
enable_requester_pays(workspace_name, workspace_namespace)
410+
411+
drs_uri_list = [item['drs_uri'] for item in manifest]
412+
if is_requester_pays(drs_uri_list):
413+
enable_requester_pays(workspace_name, workspace_namespace)
414+
369415
with DRSCopyClient(indicator_type=indicator_type) as cc:
370416
cc.workspace = workspace_name
371417
cc.workspace_namespace = workspace_namespace
@@ -381,7 +427,9 @@ def extract_tar_gz(drs_url: str,
381427
Default extraction is to the bucket for 'workspace'.
382428
"""
383429
dst = dst or f"gs://{workspace.get_workspace_bucket(workspace_name)}"
384-
enable_requester_pays(workspace_name, workspace_namespace)
430+
431+
if is_requester_pays([drs_url]):
432+
enable_requester_pays(workspace_name, workspace_namespace)
385433
blob = get_drs_blob(drs_url, billing_project)
386434
with blob.open() as fh:
387435
tar_gz.extract(fh, dst)

tests/test_drs.py

Lines changed: 94 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,14 @@ def test_copy_to_bucket(self):
6969

7070

7171
class TestTerraNotebookUtilsDRS(SuppressWarningsMixin, unittest.TestCase):
72-
drs_url = "drs://dg.4503/95cc4ae1-dee7-4266-8b97-77cf46d83d35" # DRS response contains GS native url + credentials
73-
drs_url_signed = "drs://dg.4DFC:dg.4DFC/00040a6f-b7e5-4e5c-ab57-ee92a0ba8201" # DRS response contains signed URL
72+
# BDC data on Google
73+
# DRS response contains GS native url + credentials + signed URL if requested
74+
drs_url = "drs://dg.4503/95cc4ae1-dee7-4266-8b97-77cf46d83d35"
75+
# CRDC PDC data on AWS (only)
76+
# DRS response contains signed URL if requested
77+
drs_url_signed = "drs://dg.4DFC:dg.4DFC/00040a6f-b7e5-4e5c-ab57-ee92a0ba8201"
78+
# DRS response contains GS native url + credentials (no signed URL)
79+
drs_url_requester_pays = "drs://dg.ANV0/1b1ee6fc-6560-4b08-9c44-36d46bf4daf1"
7480
jade_dev_url = "drs://jade.datarepo-dev.broadinstitute.org/v1_0c86170e-312d-4b39-a0a4-2a2bfaa24c7a_" \
7581
"c0e40912-8b14-43f6-9a2f-b278144d0060"
7682

@@ -396,7 +402,15 @@ def test_extract_tar_gz(self):
396402
self.assertEqual(data, expected_data)
397403

398404
@testmode("workspace_access")
399-
def test_arg_propagation(self):
405+
def test_is_requester_pays(self):
406+
self.assertFalse(drs.is_requester_pays([self.drs_url]))
407+
self.assertFalse(drs.is_requester_pays([self.drs_url_signed]))
408+
self.assertTrue(drs.is_requester_pays([self.drs_url_requester_pays]))
409+
self.assertTrue(drs.is_requester_pays(
410+
[self.drs_url, self.drs_url_requester_pays, self.drs_url_signed]))
411+
412+
@testmode("workspace_access")
413+
def test_arg_propagation_and_enable_requester_pays(self):
400414
resp_json = mock.MagicMock(return_value={
401415
'googleServiceAccount': {'data': {'project_id': "foo"}},
402416
'dos': {'data_object': {'urls': [{'url': 'gs://asdf/asdf'}]}}
@@ -410,19 +424,93 @@ def test_arg_propagation(self):
410424
es.enter_context(mock.patch("terra_notebook_utils.drs.GSBlob.open"))
411425
es.enter_context(mock.patch("terra_notebook_utils.drs.http", post=requests_post))
412426
with mock.patch("terra_notebook_utils.drs.enable_requester_pays") as enable_requester_pays:
427+
with self.subTest("Access URL"):
428+
try:
429+
drs.access(self.drs_url_requester_pays)
430+
except Exception:
431+
pass # Ignore downstream error due to complexity of mocking
432+
enable_requester_pays.assert_called_with(WORKSPACE_NAME, WORKSPACE_NAMESPACE)
413433
with self.subTest("Copy to local"):
434+
enable_requester_pays.reset_mock()
414435
with tempfile.NamedTemporaryFile() as tf:
415-
drs.copy(self.drs_url, tf.name)
436+
drs.copy(self.drs_url_requester_pays, tf.name)
416437
enable_requester_pays.assert_called_with(WORKSPACE_NAME, WORKSPACE_NAMESPACE)
417438
with self.subTest("Copy to bucket"):
418439
enable_requester_pays.reset_mock()
419-
drs.copy(self.drs_url, "gs://some_bucket/some_key")
440+
drs.copy(self.drs_url_requester_pays, "gs://some_bucket/some_key")
420441
enable_requester_pays.assert_called_with(WORKSPACE_NAME, WORKSPACE_NAMESPACE)
442+
with self.subTest("Copy batch urls"):
443+
enable_requester_pays.reset_mock()
444+
with tempfile.TemporaryDirectory() as td_name:
445+
drs.copy_batch_urls([self.drs_url, self.drs_url_requester_pays], td_name)
446+
enable_requester_pays.assert_called_with(WORKSPACE_NAME, WORKSPACE_NAMESPACE)
447+
with self.subTest("Copy batch manifest"):
448+
enable_requester_pays.reset_mock()
449+
with tempfile.TemporaryDirectory() as td_name:
450+
manifest = [{"drs_uri": self.drs_url, "dst": td_name},
451+
{"drs_uri": self.drs_url_requester_pays, "dst": td_name}]
452+
drs.copy_batch_manifest(manifest)
453+
enable_requester_pays.assert_called_with(WORKSPACE_NAME, WORKSPACE_NAMESPACE)
421454
with self.subTest("Extract tarball"):
422455
enable_requester_pays.reset_mock()
423-
drs.extract_tar_gz(self.drs_url)
456+
drs.extract_tar_gz(self.drs_url_requester_pays)
457+
enable_requester_pays.assert_called_with(WORKSPACE_NAME, WORKSPACE_NAMESPACE)
458+
with self.subTest("Head"):
459+
enable_requester_pays.reset_mock()
460+
drs.head(self.drs_url_requester_pays)
424461
enable_requester_pays.assert_called_with(WORKSPACE_NAME, WORKSPACE_NAMESPACE)
425462

463+
@testmode("workspace_access")
464+
def test_enable_requester_pays_not_called_when_not_necessary(self):
465+
resp_json = mock.MagicMock(return_value={
466+
'googleServiceAccount': {'data': {'project_id': "foo"}},
467+
'dos': {'data_object': {'urls': [{'url': 'gs://asdf/asdf'}]}}
468+
})
469+
requests_post = mock.MagicMock(return_value=mock.MagicMock(status_code=200, json=resp_json))
470+
with ExitStack() as es:
471+
es.enter_context(mock.patch("terra_notebook_utils.drs.gs.get_client"))
472+
es.enter_context(mock.patch("terra_notebook_utils.drs.tar_gz"))
473+
es.enter_context(mock.patch("terra_notebook_utils.blobstore.gs.GSBlob.download"))
474+
es.enter_context(mock.patch("terra_notebook_utils.drs.DRSCopyClient"))
475+
es.enter_context(mock.patch("terra_notebook_utils.drs.GSBlob.open"))
476+
es.enter_context(mock.patch("terra_notebook_utils.drs.http", post=requests_post))
477+
with mock.patch("terra_notebook_utils.drs.enable_requester_pays") as enable_requester_pays:
478+
with self.subTest("Access URL"):
479+
try:
480+
drs.access(self.drs_url)
481+
except Exception:
482+
pass # Ignore downstream error due to complexity of mocking
483+
enable_requester_pays.assert_not_called()
484+
with self.subTest("Copy to local"):
485+
enable_requester_pays.reset_mock()
486+
with tempfile.NamedTemporaryFile() as tf:
487+
drs.copy(self.drs_url, tf.name)
488+
enable_requester_pays.assert_not_called()
489+
with self.subTest("Copy to bucket"):
490+
enable_requester_pays.reset_mock()
491+
drs.copy(self.drs_url, "gs://some_bucket/some_key")
492+
enable_requester_pays.assert_not_called()
493+
with self.subTest("Copy batch urls"):
494+
enable_requester_pays.reset_mock()
495+
with tempfile.TemporaryDirectory() as td_name:
496+
drs.copy_batch_urls([self.drs_url, self.drs_url], td_name)
497+
enable_requester_pays.assert_not_called()
498+
with self.subTest("Copy batch manifest"):
499+
enable_requester_pays.reset_mock()
500+
with tempfile.TemporaryDirectory() as td_name:
501+
manifest = [{"drs_uri": self.drs_url, "dst": td_name},
502+
{"drs_uri": self.drs_url, "dst": td_name}]
503+
drs.copy_batch_manifest(manifest)
504+
enable_requester_pays.assert_not_called()
505+
with self.subTest("Extract tarball"):
506+
enable_requester_pays.reset_mock()
507+
drs.extract_tar_gz(self.drs_url)
508+
enable_requester_pays.assert_not_called()
509+
with self.subTest("Head"):
510+
enable_requester_pays.reset_mock()
511+
drs.head(self.drs_url)
512+
enable_requester_pays.assert_not_called()
513+
426514
# test for when we get everything we wanted in the drs_resolver response
427515
def test_drs_resolver_response(self):
428516
resp_json = mock.MagicMock(return_value=self.mock_jdr_response)

0 commit comments

Comments
 (0)