From 70b55c9652a2f11a54be375ee0e737c76b69eed7 Mon Sep 17 00:00:00 2001 From: Michael Baumann <84099036+mbaumann-broad@users.noreply.github.com> Date: Tue, 24 Jan 2023 17:22:11 -0800 Subject: [PATCH] Only enable requester pays when necessary #400 (#402) Only enable (Google) requester pays data access when the given DRS URIs require it and the platform TNU is running on supports it. --- terra_notebook_utils/drs.py | 72 +++++++++++++++++++++----- tests/test_drs.py | 100 +++++++++++++++++++++++++++++++++--- 2 files changed, 154 insertions(+), 18 deletions(-) diff --git a/terra_notebook_utils/drs.py b/terra_notebook_utils/drs.py index d8d4e70..b611712 100644 --- a/terra_notebook_utils/drs.py +++ b/terra_notebook_utils/drs.py @@ -9,8 +9,8 @@ from terra_notebook_utils import WORKSPACE_BUCKET, WORKSPACE_NAME, DRS_RESOLVER_URL, WORKSPACE_NAMESPACE, \ WORKSPACE_GOOGLE_PROJECT -from terra_notebook_utils import workspace, gs, tar_gz, TERRA_DEPLOYMENT_ENV, _GS_SCHEMA -from terra_notebook_utils.utils import is_notebook +from terra_notebook_utils import workspace, gs, tar_gz, TERRA_DEPLOYMENT_ENV, _GS_SCHEMA, ExecutionPlatform +from terra_notebook_utils.utils import is_notebook, get_execution_context from terra_notebook_utils.http import http from terra_notebook_utils.blobstore.gs import GSBlob from terra_notebook_utils.blobstore.local import LocalBlob @@ -26,6 +26,9 @@ class DRSResolutionError(Exception): pass +class RequesterPaysNotSupported(Exception): + pass + def _parse_gs_url(gs_url: str) -> Tuple[str, str]: if gs_url.startswith(_GS_SCHEMA): bucket_name, object_key = gs_url[len(_GS_SCHEMA):].split("/", 1) @@ -33,9 +36,41 @@ def _parse_gs_url(gs_url: str) -> Tuple[str, str]: else: raise RuntimeError(f'Invalid gs url schema. 
{gs_url} does not start with {_GS_SCHEMA}') + +def is_requester_pays(drs_urls: Iterable[str]) -> bool: + """ + Identify if any of the given DRS URIs require Google requester pays access + + :raises: RequesterPaysNotSupported + """ + for drs_url in drs_urls: + # Currently (1/2023), Gen3-hosted AnVIL data in GCS is the only + # DRS data requiring requester pays access. + # Even this will end when the AnVIL data is hosted in TDR. + # Note: If the Gen3 AnVIL DRS URI format is retained by TDR, this function must be updated. + + # The DRS specification (v1.2) states compact identifiers must be lowercase alphanumerical values: + # https://ga4gh.github.io/data-repository-service-schemas/preview/release/drs-1.2.0/docs/#section/DRS-URIs + # Yet for historical reasons the DRS URIs minted by Gen3 use uppercase. + # Perform a case-insensitive comparison. + anvil_drs_uri_prefix = "drs://dg.ANV0".lower() + if drs_url.strip().lower().startswith(anvil_drs_uri_prefix): + if get_execution_context().execution_platform == ExecutionPlatform.AZURE: + raise RequesterPaysNotSupported( + f"Requester pays data access is not supported on the Azure platform. Cannot access: {drs_url}" + ) + else: + return True + return False + + @lru_cache() def enable_requester_pays(workspace_name: Optional[str]=WORKSPACE_NAME, workspace_namespace: Optional[str]=WORKSPACE_NAMESPACE): + + assert get_execution_context().execution_platform != ExecutionPlatform.AZURE, \ + "Requester pays data access is not supported on the Terra Azure platform." + if not workspace_name: raise RuntimeError('Workspace name is not set. 
Please set the environment variable ' 'WORKSPACE_NAME with the name of a valid Terra Workspace.') @@ -92,11 +127,14 @@ def access(drs_url: str, workspace_namespace: Optional[str]=WORKSPACE_NAMESPACE, billing_project: Optional[str]=WORKSPACE_GOOGLE_PROJECT) -> str: """Return a signed url for a drs:// URI, if available.""" - # We enable requester pays by specifying the workspace/namespace combo, not - # with the billing project. Rawls then enables requester pays for the attached - # project, but this won't work if a user specifies a project unattached to - # the Terra workspace. - enable_requester_pays(workspace_name, workspace_namespace) + + if is_requester_pays([drs_url]): + # We enable requester pays by specifying the workspace/namespace combo, not + # with the billing project. Rawls then enables requester pays for the attached + # project, but this won't work if a user specifies a project unattached to + # the Terra workspace. + enable_requester_pays(workspace_name, workspace_namespace) + info = get_drs_info(drs_url, access_url=True) if info.access_url: @@ -232,7 +270,9 @@ def head(drs_url: str, workspace_namespace: Optional[str]=WORKSPACE_NAMESPACE, billing_project: Optional[str]=WORKSPACE_GOOGLE_PROJECT): """Head a DRS object by byte.""" - enable_requester_pays(workspace_name, workspace_namespace) + + if is_requester_pays([drs_url]): + enable_requester_pays(workspace_name, workspace_namespace) try: blob = get_drs_blob(drs_url, billing_project) with blob.open(chunk_size=num_bytes) as fh: @@ -295,7 +335,8 @@ def copy(drs_uri: str, """Copy a DRS object to either the local filesystem, or to a Google Storage location if `dst` starts with "gs://". 
""" - enable_requester_pays(workspace_name, workspace_namespace) + if is_requester_pays([drs_uri]): + enable_requester_pays(workspace_name, workspace_namespace) with DRSCopyClient(raise_on_error=True, indicator_type=indicator_type) as cc: cc.workspace = workspace_name cc.workspace_namespace = workspace_namespace @@ -340,7 +381,8 @@ def copy_batch_urls(drs_urls: Iterable[str], indicator_type: Indicator = Indicator.notebook_bar if is_notebook() else Indicator.log, workspace_name: Optional[str] = WORKSPACE_NAME, workspace_namespace: Optional[str] = WORKSPACE_NAMESPACE): - enable_requester_pays(workspace_name, workspace_namespace) + if is_requester_pays(drs_urls): + enable_requester_pays(workspace_name, workspace_namespace) with DRSCopyClient(indicator_type=indicator_type) as cc: cc.workspace = workspace_name cc.workspace_namespace = workspace_namespace @@ -365,7 +407,11 @@ def copy_batch_manifest(manifest: List[Dict[str, str]], workspace_namespace: Optional[str]=WORKSPACE_NAMESPACE): from jsonschema import validate validate(instance=manifest, schema=manifest_schema) - enable_requester_pays(workspace_name, workspace_namespace) + + drs_uri_list = [item['drs_uri'] for item in manifest] + if is_requester_pays(drs_uri_list): + enable_requester_pays(workspace_name, workspace_namespace) + with DRSCopyClient(indicator_type=indicator_type) as cc: cc.workspace = workspace_name cc.workspace_namespace = workspace_namespace @@ -381,7 +427,9 @@ def extract_tar_gz(drs_url: str, Default extraction is to the bucket for 'workspace'. 
""" dst = dst or f"gs://{workspace.get_workspace_bucket(workspace_name)}" - enable_requester_pays(workspace_name, workspace_namespace) + + if is_requester_pays([drs_url]): + enable_requester_pays(workspace_name, workspace_namespace) blob = get_drs_blob(drs_url, billing_project) with blob.open() as fh: tar_gz.extract(fh, dst) diff --git a/tests/test_drs.py b/tests/test_drs.py index 6833ee0..55f0281 100644 --- a/tests/test_drs.py +++ b/tests/test_drs.py @@ -69,8 +69,14 @@ def test_copy_to_bucket(self): class TestTerraNotebookUtilsDRS(SuppressWarningsMixin, unittest.TestCase): - drs_url = "drs://dg.4503/95cc4ae1-dee7-4266-8b97-77cf46d83d35" # DRS response contains GS native url + credentials - drs_url_signed = "drs://dg.4DFC:dg.4DFC/00040a6f-b7e5-4e5c-ab57-ee92a0ba8201" # DRS response contains signed URL + # BDC data on Google + # DRS response contains GS native url + credentials + signed URL if requested + drs_url = "drs://dg.4503/95cc4ae1-dee7-4266-8b97-77cf46d83d35" + # CRDC PDC data on AWS (only) + # DRS response contains signed URL if requested + drs_url_signed = "drs://dg.4DFC:dg.4DFC/00040a6f-b7e5-4e5c-ab57-ee92a0ba8201" + # DRS response contains GS native url + credentials (no signed URL) + drs_url_requester_pays = "drs://dg.ANV0/1b1ee6fc-6560-4b08-9c44-36d46bf4daf1" jade_dev_url = "drs://jade.datarepo-dev.broadinstitute.org/v1_0c86170e-312d-4b39-a0a4-2a2bfaa24c7a_" \ "c0e40912-8b14-43f6-9a2f-b278144d0060" @@ -396,7 +402,15 @@ def test_extract_tar_gz(self): self.assertEqual(data, expected_data) @testmode("workspace_access") - def test_arg_propagation(self): + def test_is_requester_pays(self): + self.assertFalse(drs.is_requester_pays([self.drs_url])) + self.assertFalse(drs.is_requester_pays([self.drs_url_signed])) + self.assertTrue(drs.is_requester_pays([self.drs_url_requester_pays])) + self.assertTrue(drs.is_requester_pays( + [self.drs_url, self.drs_url_requester_pays, self.drs_url_signed])) + + @testmode("workspace_access") + def 
test_arg_propagation_and_enable_requester_pays(self): resp_json = mock.MagicMock(return_value={ 'googleServiceAccount': {'data': {'project_id': "foo"}}, 'dos': {'data_object': {'urls': [{'url': 'gs://asdf/asdf'}]}} @@ -410,19 +424,93 @@ def test_arg_propagation(self): es.enter_context(mock.patch("terra_notebook_utils.drs.GSBlob.open")) es.enter_context(mock.patch("terra_notebook_utils.drs.http", post=requests_post)) with mock.patch("terra_notebook_utils.drs.enable_requester_pays") as enable_requester_pays: + with self.subTest("Access URL"): + try: + drs.access(self.drs_url_requester_pays) + except Exception: + pass # Ignore downstream error due to complexity of mocking + enable_requester_pays.assert_called_with(WORKSPACE_NAME, WORKSPACE_NAMESPACE) with self.subTest("Copy to local"): + enable_requester_pays.reset_mock() with tempfile.NamedTemporaryFile() as tf: - drs.copy(self.drs_url, tf.name) + drs.copy(self.drs_url_requester_pays, tf.name) enable_requester_pays.assert_called_with(WORKSPACE_NAME, WORKSPACE_NAMESPACE) with self.subTest("Copy to bucket"): enable_requester_pays.reset_mock() - drs.copy(self.drs_url, "gs://some_bucket/some_key") + drs.copy(self.drs_url_requester_pays, "gs://some_bucket/some_key") enable_requester_pays.assert_called_with(WORKSPACE_NAME, WORKSPACE_NAMESPACE) + with self.subTest("Copy batch urls"): + enable_requester_pays.reset_mock() + with tempfile.TemporaryDirectory() as td_name: + drs.copy_batch_urls([self.drs_url, self.drs_url_requester_pays], td_name) + enable_requester_pays.assert_called_with(WORKSPACE_NAME, WORKSPACE_NAMESPACE) + with self.subTest("Copy batch manifest"): + enable_requester_pays.reset_mock() + with tempfile.TemporaryDirectory() as td_name: + manifest = [{"drs_uri": self.drs_url, "dst": td_name}, + {"drs_uri": self.drs_url_requester_pays, "dst": td_name}] + drs.copy_batch_manifest(manifest) + enable_requester_pays.assert_called_with(WORKSPACE_NAME, WORKSPACE_NAMESPACE) with self.subTest("Extract tarball"): 
enable_requester_pays.reset_mock() - drs.extract_tar_gz(self.drs_url) + drs.extract_tar_gz(self.drs_url_requester_pays) + enable_requester_pays.assert_called_with(WORKSPACE_NAME, WORKSPACE_NAMESPACE) + with self.subTest("Head"): + enable_requester_pays.reset_mock() + drs.head(self.drs_url_requester_pays) enable_requester_pays.assert_called_with(WORKSPACE_NAME, WORKSPACE_NAMESPACE) + @testmode("workspace_access") + def test_enable_requester_pays_not_called_when_not_necessary(self): + resp_json = mock.MagicMock(return_value={ + 'googleServiceAccount': {'data': {'project_id': "foo"}}, + 'dos': {'data_object': {'urls': [{'url': 'gs://asdf/asdf'}]}} + }) + requests_post = mock.MagicMock(return_value=mock.MagicMock(status_code=200, json=resp_json)) + with ExitStack() as es: + es.enter_context(mock.patch("terra_notebook_utils.drs.gs.get_client")) + es.enter_context(mock.patch("terra_notebook_utils.drs.tar_gz")) + es.enter_context(mock.patch("terra_notebook_utils.blobstore.gs.GSBlob.download")) + es.enter_context(mock.patch("terra_notebook_utils.drs.DRSCopyClient")) + es.enter_context(mock.patch("terra_notebook_utils.drs.GSBlob.open")) + es.enter_context(mock.patch("terra_notebook_utils.drs.http", post=requests_post)) + with mock.patch("terra_notebook_utils.drs.enable_requester_pays") as enable_requester_pays: + with self.subTest("Access URL"): + try: + drs.access(self.drs_url) + except Exception: + pass # Ignore downstream error due to complexity of mocking + enable_requester_pays.assert_not_called() + with self.subTest("Copy to local"): + enable_requester_pays.reset_mock() + with tempfile.NamedTemporaryFile() as tf: + drs.copy(self.drs_url, tf.name) + enable_requester_pays.assert_not_called() + with self.subTest("Copy to bucket"): + enable_requester_pays.reset_mock() + drs.copy(self.drs_url, "gs://some_bucket/some_key") + enable_requester_pays.assert_not_called() + with self.subTest("Copy batch urls"): + enable_requester_pays.reset_mock() + with 
tempfile.TemporaryDirectory() as td_name: + drs.copy_batch_urls([self.drs_url, self.drs_url], td_name) + enable_requester_pays.assert_not_called() + with self.subTest("Copy batch manifest"): + enable_requester_pays.reset_mock() + with tempfile.TemporaryDirectory() as td_name: + manifest = [{"drs_uri": self.drs_url, "dst": td_name}, + {"drs_uri": self.drs_url, "dst": td_name}] + drs.copy_batch_manifest(manifest) + enable_requester_pays.assert_not_called() + with self.subTest("Extract tarball"): + enable_requester_pays.reset_mock() + drs.extract_tar_gz(self.drs_url) + enable_requester_pays.assert_not_called() + with self.subTest("Head"): + enable_requester_pays.reset_mock() + drs.head(self.drs_url) + enable_requester_pays.assert_not_called() + # test for when we get everything what we wanted in drs_resolver response def test_drs_resolver_response(self): resp_json = mock.MagicMock(return_value=self.mock_jdr_response)