@@ -402,7 +402,15 @@ def test_extract_tar_gz(self):
402
402
self .assertEqual (data , expected_data )
403
403
404
404
@testmode ("workspace_access" )
405
- def test_arg_propagation (self ):
405
+ def test_is_requester_pays (self ):
406
+ self .assertFalse (drs .is_requester_pays ([self .drs_url ]))
407
+ self .assertFalse (drs .is_requester_pays ([self .drs_url_signed ]))
408
+ self .assertTrue (drs .is_requester_pays ([self .drs_url_requester_pays ]))
409
+ self .assertTrue (drs .is_requester_pays (\
410
+ [self .drs_url , self .drs_url_requester_pays , self .drs_url_signed ]))
411
+
412
+ @testmode ("workspace_access" )
413
+ def test_arg_propagation_and_enable_requester_pays (self ):
406
414
resp_json = mock .MagicMock (return_value = {
407
415
'googleServiceAccount' : {'data' : {'project_id' : "foo" }},
408
416
'dos' : {'data_object' : {'urls' : [{'url' : 'gs://asdf/asdf' }]}}
@@ -452,6 +460,57 @@ def test_arg_propagation(self):
452
460
drs .head (self .drs_url_requester_pays )
453
461
enable_requester_pays .assert_called_with (WORKSPACE_NAME , WORKSPACE_NAMESPACE )
454
462
463
+ @testmode ("workspace_access" )
464
+ def test_enable_requester_pays_not_called_when_not_necessary (self ):
465
+ resp_json = mock .MagicMock (return_value = {
466
+ 'googleServiceAccount' : {'data' : {'project_id' : "foo" }},
467
+ 'dos' : {'data_object' : {'urls' : [{'url' : 'gs://asdf/asdf' }]}}
468
+ })
469
+ requests_post = mock .MagicMock (return_value = mock .MagicMock (status_code = 200 , json = resp_json ))
470
+ with ExitStack () as es :
471
+ es .enter_context (mock .patch ("terra_notebook_utils.drs.gs.get_client" ))
472
+ es .enter_context (mock .patch ("terra_notebook_utils.drs.tar_gz" ))
473
+ es .enter_context (mock .patch ("terra_notebook_utils.blobstore.gs.GSBlob.download" ))
474
+ es .enter_context (mock .patch ("terra_notebook_utils.drs.DRSCopyClient" ))
475
+ es .enter_context (mock .patch ("terra_notebook_utils.drs.GSBlob.open" ))
476
+ es .enter_context (mock .patch ("terra_notebook_utils.drs.http" , post = requests_post ))
477
+ with mock .patch ("terra_notebook_utils.drs.enable_requester_pays" ) as enable_requester_pays :
478
+ with self .subTest ("Access URL" ):
479
+ try :
480
+ drs .access (self .drs_url )
481
+ except :
482
+ pass # Ignore downstream error due to complexity of mocking
483
+ enable_requester_pays .assert_not_called ()
484
+ with self .subTest ("Copy to local" ):
485
+ enable_requester_pays .reset_mock ()
486
+ with tempfile .NamedTemporaryFile () as tf :
487
+ drs .copy (self .drs_url , tf .name )
488
+ enable_requester_pays .assert_not_called ()
489
+ with self .subTest ("Copy to bucket" ):
490
+ enable_requester_pays .reset_mock ()
491
+ drs .copy (self .drs_url , "gs://some_bucket/some_key" )
492
+ enable_requester_pays .assert_not_called ()
493
+ with self .subTest ("Copy batch urls" ):
494
+ enable_requester_pays .reset_mock ()
495
+ with tempfile .TemporaryDirectory () as td_name :
496
+ drs .copy_batch_urls ([self .drs_url , self .drs_url ], td_name )
497
+ enable_requester_pays .assert_not_called ()
498
+ with self .subTest ("Copy batch manifest" ):
499
+ enable_requester_pays .reset_mock ()
500
+ with tempfile .TemporaryDirectory () as td_name :
501
+ manifest = [{"drs_uri" : self .drs_url , "dst" : td_name }, \
502
+ {"drs_uri" : self .drs_url , "dst" : td_name }]
503
+ drs .copy_batch_manifest (manifest )
504
+ enable_requester_pays .assert_not_called ()
505
+ with self .subTest ("Extract tarball" ):
506
+ enable_requester_pays .reset_mock ()
507
+ drs .extract_tar_gz (self .drs_url )
508
+ enable_requester_pays .assert_not_called ()
509
+ with self .subTest ("Head" ):
510
+ enable_requester_pays .reset_mock ()
511
+ drs .head (self .drs_url )
512
+ enable_requester_pays .assert_not_called ()
513
+
455
514
# test for when we get everything what we wanted in drs_resolver response
456
515
def test_drs_resolver_response (self ):
457
516
resp_json = mock .MagicMock (return_value = self .mock_jdr_response )
0 commit comments