Skip to content

Commit c4054aa

Browse files
author
Michael Baumann
committed
Add drs tests for enable_requester_pays only called when necessary
1 parent 9b89b76 commit c4054aa

File tree

1 file changed

+60
-1
lines changed

1 file changed

+60
-1
lines changed

tests/test_drs.py

+60-1
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,15 @@ def test_extract_tar_gz(self):
402402
self.assertEqual(data, expected_data)
403403

404404
@testmode("workspace_access")
405-
def test_arg_propagation(self):
405+
def test_is_requester_pays(self):
406+
self.assertFalse(drs.is_requester_pays([self.drs_url]))
407+
self.assertFalse(drs.is_requester_pays([self.drs_url_signed]))
408+
self.assertTrue(drs.is_requester_pays([self.drs_url_requester_pays]))
409+
self.assertTrue(drs.is_requester_pays(\
410+
[self.drs_url, self.drs_url_requester_pays, self.drs_url_signed]))
411+
412+
@testmode("workspace_access")
413+
def test_arg_propagation_and_enable_requester_pays(self):
406414
resp_json = mock.MagicMock(return_value={
407415
'googleServiceAccount': {'data': {'project_id': "foo"}},
408416
'dos': {'data_object': {'urls': [{'url': 'gs://asdf/asdf'}]}}
@@ -452,6 +460,57 @@ def test_arg_propagation(self):
452460
drs.head(self.drs_url_requester_pays)
453461
enable_requester_pays.assert_called_with(WORKSPACE_NAME, WORKSPACE_NAMESPACE)
454462

463+
@testmode("workspace_access")
464+
def test_enable_requester_pays_not_called_when_not_necessary(self):
465+
resp_json = mock.MagicMock(return_value={
466+
'googleServiceAccount': {'data': {'project_id': "foo"}},
467+
'dos': {'data_object': {'urls': [{'url': 'gs://asdf/asdf'}]}}
468+
})
469+
requests_post = mock.MagicMock(return_value=mock.MagicMock(status_code=200, json=resp_json))
470+
with ExitStack() as es:
471+
es.enter_context(mock.patch("terra_notebook_utils.drs.gs.get_client"))
472+
es.enter_context(mock.patch("terra_notebook_utils.drs.tar_gz"))
473+
es.enter_context(mock.patch("terra_notebook_utils.blobstore.gs.GSBlob.download"))
474+
es.enter_context(mock.patch("terra_notebook_utils.drs.DRSCopyClient"))
475+
es.enter_context(mock.patch("terra_notebook_utils.drs.GSBlob.open"))
476+
es.enter_context(mock.patch("terra_notebook_utils.drs.http", post=requests_post))
477+
with mock.patch("terra_notebook_utils.drs.enable_requester_pays") as enable_requester_pays:
478+
with self.subTest("Access URL"):
479+
try:
480+
drs.access(self.drs_url)
481+
except:
482+
pass # Ignore downstream error due to complexity of mocking
483+
enable_requester_pays.assert_not_called()
484+
with self.subTest("Copy to local"):
485+
enable_requester_pays.reset_mock()
486+
with tempfile.NamedTemporaryFile() as tf:
487+
drs.copy(self.drs_url, tf.name)
488+
enable_requester_pays.assert_not_called()
489+
with self.subTest("Copy to bucket"):
490+
enable_requester_pays.reset_mock()
491+
drs.copy(self.drs_url, "gs://some_bucket/some_key")
492+
enable_requester_pays.assert_not_called()
493+
with self.subTest("Copy batch urls"):
494+
enable_requester_pays.reset_mock()
495+
with tempfile.TemporaryDirectory() as td_name:
496+
drs.copy_batch_urls([self.drs_url, self.drs_url], td_name)
497+
enable_requester_pays.assert_not_called()
498+
with self.subTest("Copy batch manifest"):
499+
enable_requester_pays.reset_mock()
500+
with tempfile.TemporaryDirectory() as td_name:
501+
manifest = [{"drs_uri": self.drs_url, "dst": td_name}, \
502+
{"drs_uri": self.drs_url, "dst": td_name}]
503+
drs.copy_batch_manifest(manifest)
504+
enable_requester_pays.assert_not_called()
505+
with self.subTest("Extract tarball"):
506+
enable_requester_pays.reset_mock()
507+
drs.extract_tar_gz(self.drs_url)
508+
enable_requester_pays.assert_not_called()
509+
with self.subTest("Head"):
510+
enable_requester_pays.reset_mock()
511+
drs.head(self.drs_url)
512+
enable_requester_pays.assert_not_called()
513+
455514
# test for when we get everything what we wanted in drs_resolver response
456515
def test_drs_resolver_response(self):
457516
resp_json = mock.MagicMock(return_value=self.mock_jdr_response)

0 commit comments

Comments
 (0)