@@ -69,8 +69,14 @@ def test_copy_to_bucket(self):
6969
7070
7171class TestTerraNotebookUtilsDRS (SuppressWarningsMixin , unittest .TestCase ):
72- drs_url = "drs://dg.4503/95cc4ae1-dee7-4266-8b97-77cf46d83d35" # DRS response contains GS native url + credentials
73- drs_url_signed = "drs://dg.4DFC:dg.4DFC/00040a6f-b7e5-4e5c-ab57-ee92a0ba8201" # DRS response contains signed URL
72+ # BDC data on Google
73+ # DRS response contains GS native url + credentials + signed URL if requested
74+ drs_url = "drs://dg.4503/95cc4ae1-dee7-4266-8b97-77cf46d83d35"
75+ # CRDC PDC data on AWS (only)
76+ # DRS response contains signed URL if requested
77+ drs_url_signed = "drs://dg.4DFC:dg.4DFC/00040a6f-b7e5-4e5c-ab57-ee92a0ba8201"
78+ # DRS response contains GS native url + credentials (no signed URL)
79+ drs_url_requester_pays = "drs://dg.ANV0/1b1ee6fc-6560-4b08-9c44-36d46bf4daf1"
7480 jade_dev_url = "drs://jade.datarepo-dev.broadinstitute.org/v1_0c86170e-312d-4b39-a0a4-2a2bfaa24c7a_" \
7581 "c0e40912-8b14-43f6-9a2f-b278144d0060"
7682
@@ -396,7 +402,15 @@ def test_extract_tar_gz(self):
396402 self .assertEqual (data , expected_data )
397403
398404 @testmode ("workspace_access" )
399- def test_arg_propagation (self ):
405+ def test_is_requester_pays (self ):
406+ self .assertFalse (drs .is_requester_pays ([self .drs_url ]))
407+ self .assertFalse (drs .is_requester_pays ([self .drs_url_signed ]))
408+ self .assertTrue (drs .is_requester_pays ([self .drs_url_requester_pays ]))
409+ self .assertTrue (drs .is_requester_pays (
410+ [self .drs_url , self .drs_url_requester_pays , self .drs_url_signed ]))
411+
412+ @testmode ("workspace_access" )
413+ def test_arg_propagation_and_enable_requester_pays (self ):
400414 resp_json = mock .MagicMock (return_value = {
401415 'googleServiceAccount' : {'data' : {'project_id' : "foo" }},
402416 'dos' : {'data_object' : {'urls' : [{'url' : 'gs://asdf/asdf' }]}}
@@ -410,19 +424,93 @@ def test_arg_propagation(self):
410424 es .enter_context (mock .patch ("terra_notebook_utils.drs.GSBlob.open" ))
411425 es .enter_context (mock .patch ("terra_notebook_utils.drs.http" , post = requests_post ))
412426 with mock .patch ("terra_notebook_utils.drs.enable_requester_pays" ) as enable_requester_pays :
427+ with self .subTest ("Access URL" ):
428+ try :
429+ drs .access (self .drs_url_requester_pays )
430+ except Exception :
431+ pass # Ignore downstream error due to complexity of mocking
432+ enable_requester_pays .assert_called_with (WORKSPACE_NAME , WORKSPACE_NAMESPACE )
413433 with self .subTest ("Copy to local" ):
434+ enable_requester_pays .reset_mock ()
414435 with tempfile .NamedTemporaryFile () as tf :
415- drs .copy (self .drs_url , tf .name )
436+ drs .copy (self .drs_url_requester_pays , tf .name )
416437 enable_requester_pays .assert_called_with (WORKSPACE_NAME , WORKSPACE_NAMESPACE )
417438 with self .subTest ("Copy to bucket" ):
418439 enable_requester_pays .reset_mock ()
419- drs .copy (self .drs_url , "gs://some_bucket/some_key" )
440+ drs .copy (self .drs_url_requester_pays , "gs://some_bucket/some_key" )
420441 enable_requester_pays .assert_called_with (WORKSPACE_NAME , WORKSPACE_NAMESPACE )
442+ with self .subTest ("Copy batch urls" ):
443+ enable_requester_pays .reset_mock ()
444+ with tempfile .TemporaryDirectory () as td_name :
445+ drs .copy_batch_urls ([self .drs_url , self .drs_url_requester_pays ], td_name )
446+ enable_requester_pays .assert_called_with (WORKSPACE_NAME , WORKSPACE_NAMESPACE )
447+ with self .subTest ("Copy batch manifest" ):
448+ enable_requester_pays .reset_mock ()
449+ with tempfile .TemporaryDirectory () as td_name :
450+ manifest = [{"drs_uri" : self .drs_url , "dst" : td_name },
451+ {"drs_uri" : self .drs_url_requester_pays , "dst" : td_name }]
452+ drs .copy_batch_manifest (manifest )
453+ enable_requester_pays .assert_called_with (WORKSPACE_NAME , WORKSPACE_NAMESPACE )
421454 with self .subTest ("Extract tarball" ):
422455 enable_requester_pays .reset_mock ()
423- drs .extract_tar_gz (self .drs_url )
456+ drs .extract_tar_gz (self .drs_url_requester_pays )
457+ enable_requester_pays .assert_called_with (WORKSPACE_NAME , WORKSPACE_NAMESPACE )
458+ with self .subTest ("Head" ):
459+ enable_requester_pays .reset_mock ()
460+ drs .head (self .drs_url_requester_pays )
424461 enable_requester_pays .assert_called_with (WORKSPACE_NAME , WORKSPACE_NAMESPACE )
425462
463+ @testmode ("workspace_access" )
464+ def test_enable_requester_pays_not_called_when_not_necessary (self ):
465+ resp_json = mock .MagicMock (return_value = {
466+ 'googleServiceAccount' : {'data' : {'project_id' : "foo" }},
467+ 'dos' : {'data_object' : {'urls' : [{'url' : 'gs://asdf/asdf' }]}}
468+ })
469+ requests_post = mock .MagicMock (return_value = mock .MagicMock (status_code = 200 , json = resp_json ))
470+ with ExitStack () as es :
471+ es .enter_context (mock .patch ("terra_notebook_utils.drs.gs.get_client" ))
472+ es .enter_context (mock .patch ("terra_notebook_utils.drs.tar_gz" ))
473+ es .enter_context (mock .patch ("terra_notebook_utils.blobstore.gs.GSBlob.download" ))
474+ es .enter_context (mock .patch ("terra_notebook_utils.drs.DRSCopyClient" ))
475+ es .enter_context (mock .patch ("terra_notebook_utils.drs.GSBlob.open" ))
476+ es .enter_context (mock .patch ("terra_notebook_utils.drs.http" , post = requests_post ))
477+ with mock .patch ("terra_notebook_utils.drs.enable_requester_pays" ) as enable_requester_pays :
478+ with self .subTest ("Access URL" ):
479+ try :
480+ drs .access (self .drs_url )
481+ except Exception :
482+ pass # Ignore downstream error due to complexity of mocking
483+ enable_requester_pays .assert_not_called ()
484+ with self .subTest ("Copy to local" ):
485+ enable_requester_pays .reset_mock ()
486+ with tempfile .NamedTemporaryFile () as tf :
487+ drs .copy (self .drs_url , tf .name )
488+ enable_requester_pays .assert_not_called ()
489+ with self .subTest ("Copy to bucket" ):
490+ enable_requester_pays .reset_mock ()
491+ drs .copy (self .drs_url , "gs://some_bucket/some_key" )
492+ enable_requester_pays .assert_not_called ()
493+ with self .subTest ("Copy batch urls" ):
494+ enable_requester_pays .reset_mock ()
495+ with tempfile .TemporaryDirectory () as td_name :
496+ drs .copy_batch_urls ([self .drs_url , self .drs_url ], td_name )
497+ enable_requester_pays .assert_not_called ()
498+ with self .subTest ("Copy batch manifest" ):
499+ enable_requester_pays .reset_mock ()
500+ with tempfile .TemporaryDirectory () as td_name :
501+ manifest = [{"drs_uri" : self .drs_url , "dst" : td_name },
502+ {"drs_uri" : self .drs_url , "dst" : td_name }]
503+ drs .copy_batch_manifest (manifest )
504+ enable_requester_pays .assert_not_called ()
505+ with self .subTest ("Extract tarball" ):
506+ enable_requester_pays .reset_mock ()
507+ drs .extract_tar_gz (self .drs_url )
508+ enable_requester_pays .assert_not_called ()
509+ with self .subTest ("Head" ):
510+ enable_requester_pays .reset_mock ()
511+ drs .head (self .drs_url )
512+ enable_requester_pays .assert_not_called ()
513+
426514 # test for when we get everything what we wanted in drs_resolver response
427515 def test_drs_resolver_response (self ):
428516 resp_json = mock .MagicMock (return_value = self .mock_jdr_response )
0 commit comments