Skip to content

Commit 81d2df4

Browse files
author
Michael Baumann
committed
Add drs tests for enable_requester_pays only called when necessary
1 parent 9b89b76 commit 81d2df4

File tree

1 file changed

+62
-3
lines changed

1 file changed

+62
-3
lines changed

tests/test_drs.py

+62-3
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,15 @@ def test_extract_tar_gz(self):
402402
self.assertEqual(data, expected_data)
403403

404404
@testmode("workspace_access")
405-
def test_arg_propagation(self):
405+
def test_is_requester_pays(self):
406+
self.assertFalse(drs.is_requester_pays([self.drs_url]))
407+
self.assertFalse(drs.is_requester_pays([self.drs_url_signed]))
408+
self.assertTrue(drs.is_requester_pays([self.drs_url_requester_pays]))
409+
self.assertTrue(drs.is_requester_pays(
410+
[self.drs_url, self.drs_url_requester_pays, self.drs_url_signed]))
411+
412+
@testmode("workspace_access")
413+
def test_arg_propagation_and_enable_requester_pays(self):
406414
resp_json = mock.MagicMock(return_value={
407415
'googleServiceAccount': {'data': {'project_id': "foo"}},
408416
'dos': {'data_object': {'urls': [{'url': 'gs://asdf/asdf'}]}}
@@ -419,7 +427,7 @@ def test_arg_propagation(self):
419427
with self.subTest("Access URL"):
420428
try:
421429
drs.access(self.drs_url_requester_pays)
422-
except:
430+
except Exception:
423431
pass # Ignore downstream error due to complexity of mocking
424432
enable_requester_pays.assert_called_with(WORKSPACE_NAME, WORKSPACE_NAMESPACE)
425433
with self.subTest("Copy to local"):
@@ -439,7 +447,7 @@ def test_arg_propagation(self):
439447
with self.subTest("Copy batch manifest"):
440448
enable_requester_pays.reset_mock()
441449
with tempfile.TemporaryDirectory() as td_name:
442-
manifest = [{"drs_uri": self.drs_url, "dst": td_name}, \
450+
manifest = [{"drs_uri": self.drs_url, "dst": td_name},
443451
{"drs_uri": self.drs_url_requester_pays, "dst": td_name}]
444452
drs.copy_batch_manifest(manifest)
445453
enable_requester_pays.assert_called_with(WORKSPACE_NAME, WORKSPACE_NAMESPACE)
@@ -452,6 +460,57 @@ def test_arg_propagation(self):
452460
drs.head(self.drs_url_requester_pays)
453461
enable_requester_pays.assert_called_with(WORKSPACE_NAME, WORKSPACE_NAMESPACE)
454462

463+
@testmode("workspace_access")
464+
def test_enable_requester_pays_not_called_when_not_necessary(self):
465+
resp_json = mock.MagicMock(return_value={
466+
'googleServiceAccount': {'data': {'project_id': "foo"}},
467+
'dos': {'data_object': {'urls': [{'url': 'gs://asdf/asdf'}]}}
468+
})
469+
requests_post = mock.MagicMock(return_value=mock.MagicMock(status_code=200, json=resp_json))
470+
with ExitStack() as es:
471+
es.enter_context(mock.patch("terra_notebook_utils.drs.gs.get_client"))
472+
es.enter_context(mock.patch("terra_notebook_utils.drs.tar_gz"))
473+
es.enter_context(mock.patch("terra_notebook_utils.blobstore.gs.GSBlob.download"))
474+
es.enter_context(mock.patch("terra_notebook_utils.drs.DRSCopyClient"))
475+
es.enter_context(mock.patch("terra_notebook_utils.drs.GSBlob.open"))
476+
es.enter_context(mock.patch("terra_notebook_utils.drs.http", post=requests_post))
477+
with mock.patch("terra_notebook_utils.drs.enable_requester_pays") as enable_requester_pays:
478+
with self.subTest("Access URL"):
479+
try:
480+
drs.access(self.drs_url)
481+
except Exception:
482+
pass # Ignore downstream error due to complexity of mocking
483+
enable_requester_pays.assert_not_called()
484+
with self.subTest("Copy to local"):
485+
enable_requester_pays.reset_mock()
486+
with tempfile.NamedTemporaryFile() as tf:
487+
drs.copy(self.drs_url, tf.name)
488+
enable_requester_pays.assert_not_called()
489+
with self.subTest("Copy to bucket"):
490+
enable_requester_pays.reset_mock()
491+
drs.copy(self.drs_url, "gs://some_bucket/some_key")
492+
enable_requester_pays.assert_not_called()
493+
with self.subTest("Copy batch urls"):
494+
enable_requester_pays.reset_mock()
495+
with tempfile.TemporaryDirectory() as td_name:
496+
drs.copy_batch_urls([self.drs_url, self.drs_url], td_name)
497+
enable_requester_pays.assert_not_called()
498+
with self.subTest("Copy batch manifest"):
499+
enable_requester_pays.reset_mock()
500+
with tempfile.TemporaryDirectory() as td_name:
501+
manifest = [{"drs_uri": self.drs_url, "dst": td_name},
502+
{"drs_uri": self.drs_url, "dst": td_name}]
503+
drs.copy_batch_manifest(manifest)
504+
enable_requester_pays.assert_not_called()
505+
with self.subTest("Extract tarball"):
506+
enable_requester_pays.reset_mock()
507+
drs.extract_tar_gz(self.drs_url)
508+
enable_requester_pays.assert_not_called()
509+
with self.subTest("Head"):
510+
enable_requester_pays.reset_mock()
511+
drs.head(self.drs_url)
512+
enable_requester_pays.assert_not_called()
513+
455514
# test for when we get everything what we wanted in drs_resolver response
456515
def test_drs_resolver_response(self):
457516
resp_json = mock.MagicMock(return_value=self.mock_jdr_response)

0 commit comments

Comments
 (0)