1
+ import os
2
+ import time
3
+ import base64
4
+ import hashlib
5
+ import requests
6
+
7
+
8
+ class CloudDeepScan (object ):
9
+
10
+ def __init__ (self , token_endpoint , rest_hostname , client_id , client_secret ):
11
+ self .token_endpoint = token_endpoint
12
+ self .rest_hostname = rest_hostname
13
+ self .client_id = client_id
14
+ self .client_secret = client_secret
15
+
16
+ self .__token = None
17
+ self .__token_expires_at = None
18
+
19
+ def upload_sample (self , sample_path ):
20
+ file_name , file_size = self .__get_file_info (path = sample_path )
21
+ upload = self .__create_upload (file_name = file_name , file_size = file_size )
22
+ etags = self .__upload_parts (sample_path = sample_path , parts = upload ["parts" ])
23
+ self .__complete_upload (upload_id = upload ["upload_id" ], etags = etags , object_key = upload ["object_key" ])
24
+ return upload ["submission_id" ]
25
+
26
+ def download_report (self , content_hash , output_path ):
27
+ pass
28
+
29
+ def __do_rest_api_request (self , method , endpoint , body = None , params = None ):
30
+ token = self .__get_authorization_token ()
31
+ url = f"{ self .rest_hostname } { endpoint } "
32
+ headers = {
33
+ "Authorization" : f"Bearer { token } "
34
+ }
35
+ try :
36
+ response = requests .request (method , url , data = body , headers = headers , timeout = 10 )
37
+ except :
38
+ pass
39
+ return response
40
+
41
+ def __get_authorization_token (self ):
42
+ """Acquires access token via OAuth2.0 client credential flow:
43
+ https://www.rfc-editor.org/rfc/rfc6749#section-4.4
44
+ """
45
+ if not self .__is_token_valid ():
46
+ token_leeway = 20 # Add leeway to eliminate window where token may be expired on the server but not for us due to timings
47
+ try :
48
+ token_response = requests .post (self .token_endpoint , data = {"grant_type" : "client_credentials" }, auth = (self .client_id , self .client_secret ), timeout = 10 )
49
+ token_response .raise_for_status ()
50
+ token_data = token_response .json ()
51
+ except :
52
+ # TODO error handling
53
+ raise
54
+
55
+ try :
56
+ self .__token = token_data ["access_token" ]
57
+ self .__token_expires_at = time .time () + token_data ["expires_in" ] - token_leeway
58
+ except KeyError :
59
+ # TODO error handling
60
+ raise
61
+
62
+ return self .__token
63
+
64
+ def __is_token_valid (self ):
65
+ if self .__token_expires_at is None or self .__token is None :
66
+ return False
67
+ return time .time () > self .__token_expires_at
68
+
69
+ def __upload_parts (self , sample_path , parts ):
70
+ etags = []
71
+ with open (sample_path , "rb" ) as f :
72
+ for part in parts :
73
+ data = f .read (part ["content_length" ])
74
+ etag = self .__upload_part_to_s3 (url = part ["url" ], data = data )
75
+ etags .append (etag )
76
+ return etags
77
+
78
+ def __upload_part_to_s3 (self , url , data ):
79
+ content_hash = hashlib .md5 (data ).digest ()
80
+ encoded_hash = base64 .b64encode (content_hash )
81
+ headers = {
82
+ "Content-MD5" : encoded_hash .decode (),
83
+ "Content-Length" : len (body )
84
+ }
85
+ response = requests .put (url , data = data , timeout = 10 )
86
+ response .raise_for_status ()
87
+ return response .headers ["ETag" ]
88
+
89
+ def __get_file_info (self , path ):
90
+ stat_result = os .stat (path )
91
+ file_size = stat_result .st_size
92
+ file_name = os .path .basename (path )
93
+ return file_name , file_size
94
+
95
+ def __create_upload (self , file_name , file_size ):
96
+ response = self .__do_rest_api_request ("POST" , "/api/v1/uploads" , body = {"file_name" : file_name , "file_size" : file_size })
97
+ response .raise_for_status ()
98
+ return response .json ()
99
+
100
+ def __complete_upload (self , upload_id , etags , object_key ):
101
+ response = self .__do_rest_api_request ("PATCH" , f"/api/v1/uploads/{ upload_id } " , body = {"object_key" : object_key , "etags" : etags })
102
+ response .raise_for_status ()
103
+
0 commit comments