Skip to content

Commit

Permalink
Only enable requester pays when necessary #400 (#402)
Browse files Browse the repository at this point in the history
Only enable (Google) requester pays data access when the
given DRS URIs require it and the platform TNU is running
on supports it.
  • Loading branch information
mbaumann-broad authored Jan 25, 2023
1 parent 68b7888 commit 70b55c9
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 18 deletions.
72 changes: 60 additions & 12 deletions terra_notebook_utils/drs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

from terra_notebook_utils import WORKSPACE_BUCKET, WORKSPACE_NAME, DRS_RESOLVER_URL, WORKSPACE_NAMESPACE, \
WORKSPACE_GOOGLE_PROJECT
from terra_notebook_utils import workspace, gs, tar_gz, TERRA_DEPLOYMENT_ENV, _GS_SCHEMA
from terra_notebook_utils.utils import is_notebook
from terra_notebook_utils import workspace, gs, tar_gz, TERRA_DEPLOYMENT_ENV, _GS_SCHEMA, ExecutionPlatform
from terra_notebook_utils.utils import is_notebook, get_execution_context
from terra_notebook_utils.http import http
from terra_notebook_utils.blobstore.gs import GSBlob
from terra_notebook_utils.blobstore.local import LocalBlob
Expand All @@ -26,16 +26,51 @@
class DRSResolutionError(Exception):
    """Exception type for DRS resolution failures.

    NOTE(review): the raise sites are not visible in this section; presumably
    raised when a DRS URI cannot be resolved -- confirm against callers.
    """

class RequesterPaysNotSupported(Exception):
    """Requester pays data access is not available on the current platform.

    Raised by `is_requester_pays` when a DRS URI that needs requester pays
    access is encountered while executing on the Azure platform.
    """

def _parse_gs_url(gs_url: str) -> Tuple[str, str]:
    """Split a Google Storage URL into its bucket name and object key.

    :param gs_url: URL that must begin with `_GS_SCHEMA`
    :returns: `(bucket_name, object_key)`, split at the first "/" that
              follows the schema prefix
    :raises RuntimeError: if `gs_url` does not start with `_GS_SCHEMA`
    :raises ValueError: if nothing but a bucket name follows the schema
                        (there is no "/" to split on)
    """
    # Reject anything that is not a gs:// style URL up front.
    if not gs_url.startswith(_GS_SCHEMA):
        raise RuntimeError(f'Invalid gs url schema. {gs_url} does not start with {_GS_SCHEMA}')

    # Only the first "/" separates bucket from key; later ones belong to the key.
    remainder = gs_url[len(_GS_SCHEMA):]
    bucket_name, object_key = remainder.split("/", 1)
    return bucket_name, object_key


def is_requester_pays(drs_urls: Iterable[str]) -> bool:
    """
    Report whether any of the given DRS URIs requires Google requester pays access.

    :param drs_urls: the DRS URIs to inspect
    :returns: True as soon as a URI carrying the AnVIL prefix is found,
              False when none of the URIs carries it
    :raises: RequesterPaysNotSupported if such a URI is found while the
             execution platform is Azure
    """
    # Currently (1/2023), Gen3-hosted AnVIL data in GCS is the only
    # DRS data requiring requester pays access.
    # Even this will end once the AnVIL data is hosted in TDR.
    # Note: If the Gen3 AnVIL DRS URI format is retained by TDR, this function must be updated.

    # The DRS specification (v1.2) states compact identifiers must be lowercase alphanumerical values:
    # https://ga4gh.github.io/data-repository-service-schemas/preview/release/drs-1.2.0/docs/#section/DRS-URIs
    # Yet for historical reasons the DRS URIs minted by Gen3 use uppercase,
    # so both sides of the comparison are lowercased.
    anvil_drs_uri_prefix = "drs://dg.ANV0".lower()

    for drs_url in drs_urls:
        normalized_url = drs_url.strip().lower()
        if not normalized_url.startswith(anvil_drs_uri_prefix):
            continue

        # The first AnVIL URI decides the outcome: either the platform
        # cannot do requester pays at all, or requester pays is needed.
        platform = get_execution_context().execution_platform
        if platform == ExecutionPlatform.AZURE:
            raise RequesterPaysNotSupported(
                f"Requester pays data access is not supported on the Azure platform. Cannot access: {drs_url}"
            )
        return True

    return False


@lru_cache()
def enable_requester_pays(workspace_name: Optional[str]=WORKSPACE_NAME,
workspace_namespace: Optional[str]=WORKSPACE_NAMESPACE):

assert get_execution_context().execution_platform != ExecutionPlatform.AZURE, \
"Requester pays data access is not supported on the Terra Azure platform."

if not workspace_name:
raise RuntimeError('Workspace name is not set. Please set the environment variable '
'WORKSPACE_NAME with the name of a valid Terra Workspace.')
Expand Down Expand Up @@ -92,11 +127,14 @@ def access(drs_url: str,
workspace_namespace: Optional[str]=WORKSPACE_NAMESPACE,
billing_project: Optional[str]=WORKSPACE_GOOGLE_PROJECT) -> str:
"""Return a signed url for a drs:// URI, if available."""
# We enable requester pays by specifying the workspace/namespace combo, not
# with the billing project. Rawls then enables requester pays for the attached
# project, but this won't work if a user specifies a project unattached to
# the Terra workspace.
enable_requester_pays(workspace_name, workspace_namespace)

if is_requester_pays([drs_url]):
# We enable requester pays by specifying the workspace/namespace combo, not
# with the billing project. Rawls then enables requester pays for the attached
# project, but this won't work if a user specifies a project unattached to
# the Terra workspace.
enable_requester_pays(workspace_name, workspace_namespace)

info = get_drs_info(drs_url, access_url=True)

if info.access_url:
Expand Down Expand Up @@ -232,7 +270,9 @@ def head(drs_url: str,
workspace_namespace: Optional[str]=WORKSPACE_NAMESPACE,
billing_project: Optional[str]=WORKSPACE_GOOGLE_PROJECT):
"""Head a DRS object by byte."""
enable_requester_pays(workspace_name, workspace_namespace)

if is_requester_pays([drs_url]):
enable_requester_pays(workspace_name, workspace_namespace)
try:
blob = get_drs_blob(drs_url, billing_project)
with blob.open(chunk_size=num_bytes) as fh:
Expand Down Expand Up @@ -295,7 +335,8 @@ def copy(drs_uri: str,
"""Copy a DRS object to either the local filesystem, or to a Google Storage location if `dst` starts with
"gs://".
"""
enable_requester_pays(workspace_name, workspace_namespace)
if is_requester_pays([drs_uri]):
enable_requester_pays(workspace_name, workspace_namespace)
with DRSCopyClient(raise_on_error=True, indicator_type=indicator_type) as cc:
cc.workspace = workspace_name
cc.workspace_namespace = workspace_namespace
Expand Down Expand Up @@ -340,7 +381,8 @@ def copy_batch_urls(drs_urls: Iterable[str],
indicator_type: Indicator = Indicator.notebook_bar if is_notebook() else Indicator.log,
workspace_name: Optional[str] = WORKSPACE_NAME,
workspace_namespace: Optional[str] = WORKSPACE_NAMESPACE):
enable_requester_pays(workspace_name, workspace_namespace)
if is_requester_pays(drs_urls):
enable_requester_pays(workspace_name, workspace_namespace)
with DRSCopyClient(indicator_type=indicator_type) as cc:
cc.workspace = workspace_name
cc.workspace_namespace = workspace_namespace
Expand All @@ -365,7 +407,11 @@ def copy_batch_manifest(manifest: List[Dict[str, str]],
workspace_namespace: Optional[str]=WORKSPACE_NAMESPACE):
from jsonschema import validate
validate(instance=manifest, schema=manifest_schema)
enable_requester_pays(workspace_name, workspace_namespace)

drs_uri_list = [item['drs_uri'] for item in manifest]
if is_requester_pays(drs_uri_list):
enable_requester_pays(workspace_name, workspace_namespace)

with DRSCopyClient(indicator_type=indicator_type) as cc:
cc.workspace = workspace_name
cc.workspace_namespace = workspace_namespace
Expand All @@ -381,7 +427,9 @@ def extract_tar_gz(drs_url: str,
Default extraction is to the bucket for 'workspace'.
"""
dst = dst or f"gs://{workspace.get_workspace_bucket(workspace_name)}"
enable_requester_pays(workspace_name, workspace_namespace)

if is_requester_pays([drs_url]):
enable_requester_pays(workspace_name, workspace_namespace)
blob = get_drs_blob(drs_url, billing_project)
with blob.open() as fh:
tar_gz.extract(fh, dst)
Expand Down
100 changes: 94 additions & 6 deletions tests/test_drs.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,14 @@ def test_copy_to_bucket(self):


class TestTerraNotebookUtilsDRS(SuppressWarningsMixin, unittest.TestCase):
drs_url = "drs://dg.4503/95cc4ae1-dee7-4266-8b97-77cf46d83d35" # DRS response contains GS native url + credentials
drs_url_signed = "drs://dg.4DFC:dg.4DFC/00040a6f-b7e5-4e5c-ab57-ee92a0ba8201" # DRS response contains signed URL
# BDC data on Google
# DRS response contains GS native url + credentials + signed URL if requested
drs_url = "drs://dg.4503/95cc4ae1-dee7-4266-8b97-77cf46d83d35"
# CRDC PDC data on AWS (only)
# DRS response contains signed URL if requested
drs_url_signed = "drs://dg.4DFC:dg.4DFC/00040a6f-b7e5-4e5c-ab57-ee92a0ba8201"
# DRS response contains GS native url + credentials (no signed URL)
drs_url_requester_pays = "drs://dg.ANV0/1b1ee6fc-6560-4b08-9c44-36d46bf4daf1"
jade_dev_url = "drs://jade.datarepo-dev.broadinstitute.org/v1_0c86170e-312d-4b39-a0a4-2a2bfaa24c7a_" \
"c0e40912-8b14-43f6-9a2f-b278144d0060"

Expand Down Expand Up @@ -396,7 +402,15 @@ def test_extract_tar_gz(self):
self.assertEqual(data, expected_data)

@testmode("workspace_access")
def test_arg_propagation(self):
def test_is_requester_pays(self):
    """Check `drs.is_requester_pays` against the sample DRS URIs defined on this class."""
    # The two non-AnVIL sample URIs must not be reported as requester pays.
    self.assertFalse(drs.is_requester_pays([self.drs_url]))
    self.assertFalse(drs.is_requester_pays([self.drs_url_signed]))
    # The "dg.ANV0" sample URI must be reported as requester pays.
    self.assertTrue(drs.is_requester_pays([self.drs_url_requester_pays]))
    # A single requester pays URI anywhere in the list is sufficient.
    self.assertTrue(drs.is_requester_pays(
        [self.drs_url, self.drs_url_requester_pays, self.drs_url_signed]))

@testmode("workspace_access")
def test_arg_propagation_and_enable_requester_pays(self):
resp_json = mock.MagicMock(return_value={
'googleServiceAccount': {'data': {'project_id': "foo"}},
'dos': {'data_object': {'urls': [{'url': 'gs://asdf/asdf'}]}}
Expand All @@ -410,19 +424,93 @@ def test_arg_propagation(self):
es.enter_context(mock.patch("terra_notebook_utils.drs.GSBlob.open"))
es.enter_context(mock.patch("terra_notebook_utils.drs.http", post=requests_post))
with mock.patch("terra_notebook_utils.drs.enable_requester_pays") as enable_requester_pays:
with self.subTest("Access URL"):
try:
drs.access(self.drs_url_requester_pays)
except Exception:
pass # Ignore downstream error due to complexity of mocking
enable_requester_pays.assert_called_with(WORKSPACE_NAME, WORKSPACE_NAMESPACE)
with self.subTest("Copy to local"):
enable_requester_pays.reset_mock()
with tempfile.NamedTemporaryFile() as tf:
drs.copy(self.drs_url, tf.name)
drs.copy(self.drs_url_requester_pays, tf.name)
enable_requester_pays.assert_called_with(WORKSPACE_NAME, WORKSPACE_NAMESPACE)
with self.subTest("Copy to bucket"):
enable_requester_pays.reset_mock()
drs.copy(self.drs_url, "gs://some_bucket/some_key")
drs.copy(self.drs_url_requester_pays, "gs://some_bucket/some_key")
enable_requester_pays.assert_called_with(WORKSPACE_NAME, WORKSPACE_NAMESPACE)
with self.subTest("Copy batch urls"):
enable_requester_pays.reset_mock()
with tempfile.TemporaryDirectory() as td_name:
drs.copy_batch_urls([self.drs_url, self.drs_url_requester_pays], td_name)
enable_requester_pays.assert_called_with(WORKSPACE_NAME, WORKSPACE_NAMESPACE)
with self.subTest("Copy batch manifest"):
enable_requester_pays.reset_mock()
with tempfile.TemporaryDirectory() as td_name:
manifest = [{"drs_uri": self.drs_url, "dst": td_name},
{"drs_uri": self.drs_url_requester_pays, "dst": td_name}]
drs.copy_batch_manifest(manifest)
enable_requester_pays.assert_called_with(WORKSPACE_NAME, WORKSPACE_NAMESPACE)
with self.subTest("Extract tarball"):
enable_requester_pays.reset_mock()
drs.extract_tar_gz(self.drs_url)
drs.extract_tar_gz(self.drs_url_requester_pays)
enable_requester_pays.assert_called_with(WORKSPACE_NAME, WORKSPACE_NAMESPACE)
with self.subTest("Head"):
enable_requester_pays.reset_mock()
drs.head(self.drs_url_requester_pays)
enable_requester_pays.assert_called_with(WORKSPACE_NAME, WORKSPACE_NAMESPACE)

@testmode("workspace_access")
def test_enable_requester_pays_not_called_when_not_necessary(self):
    """Verify the public `drs` entry points skip `enable_requester_pays` for a non requester pays URI.

    Every call below uses `self.drs_url`; `enable_requester_pays` is replaced
    by a mock so that any invocation can be detected.
    """
    # Canned JSON payload returned by the mocked HTTP POST.
    resp_json = mock.MagicMock(return_value={
        'googleServiceAccount': {'data': {'project_id': "foo"}},
        'dos': {'data_object': {'urls': [{'url': 'gs://asdf/asdf'}]}}
    })
    requests_post = mock.MagicMock(return_value=mock.MagicMock(status_code=200, json=resp_json))
    with ExitStack() as es:
        # Patch out the storage, copy-client and HTTP collaborators for the duration of the test.
        es.enter_context(mock.patch("terra_notebook_utils.drs.gs.get_client"))
        es.enter_context(mock.patch("terra_notebook_utils.drs.tar_gz"))
        es.enter_context(mock.patch("terra_notebook_utils.blobstore.gs.GSBlob.download"))
        es.enter_context(mock.patch("terra_notebook_utils.drs.DRSCopyClient"))
        es.enter_context(mock.patch("terra_notebook_utils.drs.GSBlob.open"))
        es.enter_context(mock.patch("terra_notebook_utils.drs.http", post=requests_post))
        with mock.patch("terra_notebook_utils.drs.enable_requester_pays") as enable_requester_pays:
            # First sub-test: the mock is fresh, so no reset is needed here.
            with self.subTest("Access URL"):
                try:
                    drs.access(self.drs_url)
                except Exception:
                    pass  # Ignore downstream error due to complexity of mocking
                enable_requester_pays.assert_not_called()
            # Each later sub-test resets the mock first so it only observes its own call.
            with self.subTest("Copy to local"):
                enable_requester_pays.reset_mock()
                with tempfile.NamedTemporaryFile() as tf:
                    drs.copy(self.drs_url, tf.name)
                enable_requester_pays.assert_not_called()
            with self.subTest("Copy to bucket"):
                enable_requester_pays.reset_mock()
                drs.copy(self.drs_url, "gs://some_bucket/some_key")
                enable_requester_pays.assert_not_called()
            with self.subTest("Copy batch urls"):
                enable_requester_pays.reset_mock()
                with tempfile.TemporaryDirectory() as td_name:
                    drs.copy_batch_urls([self.drs_url, self.drs_url], td_name)
                enable_requester_pays.assert_not_called()
            with self.subTest("Copy batch manifest"):
                enable_requester_pays.reset_mock()
                with tempfile.TemporaryDirectory() as td_name:
                    manifest = [{"drs_uri": self.drs_url, "dst": td_name},
                                {"drs_uri": self.drs_url, "dst": td_name}]
                    drs.copy_batch_manifest(manifest)
                enable_requester_pays.assert_not_called()
            with self.subTest("Extract tarball"):
                enable_requester_pays.reset_mock()
                drs.extract_tar_gz(self.drs_url)
                enable_requester_pays.assert_not_called()
            with self.subTest("Head"):
                enable_requester_pays.reset_mock()
                drs.head(self.drs_url)
                enable_requester_pays.assert_not_called()

# test for when we get everything we wanted in the drs_resolver response
def test_drs_resolver_response(self):
resp_json = mock.MagicMock(return_value=self.mock_jdr_response)
Expand Down

0 comments on commit 70b55c9

Please sign in to comment.