Skip to content

Commit ad6805d

Browse files
liyanzhang505huiguangjun
authored andcommitted
ecs ram role credentials provider
1 parent 7769f1e commit ad6805d

File tree

6 files changed

+408
-32
lines changed

6 files changed

+408
-32
lines changed

oss2/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
from . import models, exceptions, defaults
44

55
from .api import Service, Bucket
6-
from .auth import Auth, AuthV2, AnonymousAuth, StsAuth, AUTH_VERSION_1, AUTH_VERSION_2, make_auth
6+
from .auth import Auth, AuthV2, AnonymousAuth, StsAuth, AUTH_VERSION_1, AUTH_VERSION_2, make_auth, ProviderAuth, ProviderAuthV2
77
from .http import Session, CaseInsensitiveDict
8-
8+
from .credentials import EcsRamRoleCredentialsProvider, EcsRamRoleCredential, CredentialsProvider, StaticCredentialsProvider
99

1010
from .iterators import (BucketIterator, ObjectIterator, ObjectIteratorV2,
1111
MultipartUploadIterator, ObjectUploadIterator,

oss2/auth.py

Lines changed: 66 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .compat import urlquote, to_bytes, is_py2
99
from .headers import *
1010
import logging
11+
from .credentials import StaticCredentialsProvider
1112

1213
AUTH_VERSION_1 = 'v1'
1314
AUTH_VERSION_2 = 'v2'
@@ -26,11 +27,14 @@ def make_auth(access_key_id, access_key_secret, auth_version=AUTH_VERSION_1):
2627

2728
class AuthBase(object):
2829
"""用于保存用户AccessKeyId、AccessKeySecret,以及计算签名的对象。"""
29-
def __init__(self, access_key_id, access_key_secret):
30-
self.id = access_key_id.strip()
31-
self.secret = access_key_secret.strip()
30+
def __init__(self, credentials_provider):
31+
self.credentials_provider = credentials_provider
3232

3333
def _sign_rtmp_url(self, url, bucket_name, channel_name, expires, params):
34+
credentials = self.credentials_provider.get_credentials()
35+
if credentials.get_security_token():
36+
params['security-token'] = credentials.get_security_token()
37+
3438
expiration_time = int(time.time()) + expires
3539

3640
canonicalized_resource = "/%s/%s" % (bucket_name, channel_name)
@@ -52,18 +56,21 @@ def _sign_rtmp_url(self, url, bucket_name, channel_name, expires, params):
5256
logger.debug('Sign Rtmp url: string to be signed = {0}'.format(string_to_sign))
5357

5458

55-
h = hmac.new(to_bytes(self.secret), to_bytes(string_to_sign), hashlib.sha1)
59+
h = hmac.new(to_bytes(credentials.get_access_key_secret()), to_bytes(string_to_sign), hashlib.sha1)
5660
signature = utils.b64encode_as_string(h.digest())
5761

58-
p['OSSAccessKeyId'] = self.id
62+
p['OSSAccessKeyId'] = credentials.get_access_key_id()
5963
p['Expires'] = str(expiration_time)
6064
p['Signature'] = signature
6165

6266
return url + '?' + '&'.join(_param_to_quoted_query(k, v) for k, v in p.items())
6367

6468

65-
class Auth(AuthBase):
66-
"""签名版本1"""
69+
70+
class ProviderAuth(AuthBase):
71+
"""签名版本1
72+
默认构造函数同父类AuthBase,需要传递credentials_provider
73+
"""
6774
_subresource_key_set = frozenset(
6875
['response-content-type', 'response-content-language',
6976
'response-cache-control', 'logging', 'response-content-encoding',
@@ -80,32 +87,40 @@ class Auth(AuthBase):
8087
)
8188

8289
def _sign_request(self, req, bucket_name, key):
90+
credentials = self.credentials_provider.get_credentials()
91+
if credentials.get_security_token():
92+
req.headers[OSS_SECURITY_TOKEN] = credentials.get_security_token()
93+
8394
req.headers['date'] = utils.http_date()
8495

85-
signature = self.__make_signature(req, bucket_name, key)
86-
req.headers['authorization'] = "OSS {0}:{1}".format(self.id, signature)
96+
signature = self.__make_signature(req, bucket_name, key, credentials)
97+
req.headers['authorization'] = "OSS {0}:{1}".format(credentials.get_access_key_id(), signature)
8798

8899
def _sign_url(self, req, bucket_name, key, expires):
100+
credentials = self.credentials_provider.get_credentials()
101+
if credentials.get_security_token():
102+
req.params['security-token'] = credentials.get_security_token()
103+
89104
expiration_time = int(time.time()) + expires
90105

91106
req.headers['date'] = str(expiration_time)
92-
signature = self.__make_signature(req, bucket_name, key)
107+
signature = self.__make_signature(req, bucket_name, key, credentials)
93108

94-
req.params['OSSAccessKeyId'] = self.id
109+
req.params['OSSAccessKeyId'] = credentials.get_access_key_id()
95110
req.params['Expires'] = str(expiration_time)
96111
req.params['Signature'] = signature
97112

98113
return req.url + '?' + '&'.join(_param_to_quoted_query(k, v) for k, v in req.params.items())
99114

100-
def __make_signature(self, req, bucket_name, key):
115+
def __make_signature(self, req, bucket_name, key, credentials):
101116
if is_py2:
102117
string_to_sign = self.__get_string_to_sign(req, bucket_name, key)
103118
else:
104119
string_to_sign = self.__get_bytes_to_sign(req, bucket_name, key)
105120

106121
logger.debug('Make signature: string to be signed = {0}'.format(string_to_sign))
107122

108-
h = hmac.new(to_bytes(self.secret), to_bytes(string_to_sign), hashlib.sha1)
123+
h = hmac.new(to_bytes(credentials.get_access_key_secret()), to_bytes(string_to_sign), hashlib.sha1)
109124
return utils.b64encode_as_string(h.digest())
110125

111126
def __get_string_to_sign(self, req, bucket_name, key):
@@ -192,6 +207,14 @@ def __get_headers_bytes(self, req):
192207
else:
193208
return b''
194209

210+
class Auth(ProviderAuth):
211+
"""签名版本1
212+
"""
213+
def __init__(self, access_key_id, access_key_secret):
214+
credentials_provider = StaticCredentialsProvider(access_key_id.strip(), access_key_secret.strip())
215+
super(Auth, self).__init__(credentials_provider)
216+
217+
195218
class AnonymousAuth(object):
196219
"""用于匿名访问。
197220
@@ -222,19 +245,16 @@ class StsAuth(object):
222245
"""
223246
def __init__(self, access_key_id, access_key_secret, security_token, auth_version=AUTH_VERSION_1):
224247
logger.debug("Init StsAuth: access_key_id: {0}, access_key_secret: ******, security_token: ******".format(access_key_id))
225-
self.__auth = make_auth(access_key_id, access_key_secret, auth_version)
226-
self.__security_token = security_token
248+
credentials_provider = StaticCredentialsProvider(access_key_id, access_key_secret, security_token)
249+
self.__auth = ProviderAuthV2(credentials_provider) if auth_version == AUTH_VERSION_2 else ProviderAuth(credentials_provider)
227250

228251
def _sign_request(self, req, bucket_name, key):
229-
req.headers[OSS_SECURITY_TOKEN] = self.__security_token
230252
self.__auth._sign_request(req, bucket_name, key)
231253

232254
def _sign_url(self, req, bucket_name, key, expires):
233-
req.params['security-token'] = self.__security_token
234255
return self.__auth._sign_url(req, bucket_name, key, expires)
235256

236257
def _sign_rtmp_url(self, url, bucket_name, channel_name, expires, params):
237-
params['security-token'] = self.__security_token
238258
return self.__auth._sign_rtmp_url(url, bucket_name, channel_name, expires, params)
239259

240260

@@ -268,8 +288,9 @@ def v2_uri_encode(raw_text):
268288
'if-modified-since'])
269289

270290

271-
class AuthV2(AuthBase):
272-
"""签名版本2,与版本1的区别在:
291+
class ProviderAuthV2(AuthBase):
292+
"""签名版本2,默认构造函数同父类AuthBase,需要传递credentials_provider
293+
与版本1的区别在:
273294
1. 使用SHA256算法,具有更高的安全性
274295
2. 参数计算包含所有的HTTP查询参数
275296
"""
@@ -283,20 +304,24 @@ def _sign_request(self, req, bucket_name, key, in_additional_headers=None):
283304
:param key: OSS文件名
284305
:param in_additional_headers: 加入签名计算的额外header列表
285306
"""
307+
credentials = self.credentials_provider.get_credentials()
308+
if credentials.get_security_token():
309+
req.headers[OSS_SECURITY_TOKEN] = credentials.get_security_token()
310+
286311
if in_additional_headers is None:
287312
in_additional_headers = _DEFAULT_ADDITIONAL_HEADERS
288313

289314
additional_headers = self.__get_additional_headers(req, in_additional_headers)
290315

291316
req.headers['date'] = utils.http_date()
292317

293-
signature = self.__make_signature(req, bucket_name, key, additional_headers)
318+
signature = self.__make_signature(req, bucket_name, key, additional_headers, credentials)
294319

295320
if additional_headers:
296321
req.headers['authorization'] = "OSS2 AccessKeyId:{0},AdditionalHeaders:{1},Signature:{2}"\
297-
.format(self.id, ';'.join(additional_headers), signature)
322+
.format(credentials.get_access_key_id(), ';'.join(additional_headers), signature)
298323
else:
299-
req.headers['authorization'] = "OSS2 AccessKeyId:{0},Signature:{1}".format(self.id, signature)
324+
req.headers['authorization'] = "OSS2 AccessKeyId:{0},Signature:{1}".format(credentials.get_access_key_id(), signature)
300325

301326
def _sign_url(self, req, bucket_name, key, expires, in_additional_headers=None):
302327
"""返回一个签过名的URL
@@ -311,6 +336,9 @@ def _sign_url(self, req, bucket_name, key, expires, in_additional_headers=None):
311336
312337
:return: a signed URL
313338
"""
339+
credentials = self.credentials_provider.get_credentials()
340+
if credentials.get_security_token():
341+
req.params['security-token'] = credentials.get_security_token()
314342

315343
if in_additional_headers is None:
316344
in_additional_headers = set()
@@ -323,23 +351,23 @@ def _sign_url(self, req, bucket_name, key, expires, in_additional_headers=None):
323351

324352
req.params['x-oss-signature-version'] = 'OSS2'
325353
req.params['x-oss-expires'] = str(expiration_time)
326-
req.params['x-oss-access-key-id'] = self.id
354+
req.params['x-oss-access-key-id'] = credentials.get_access_key_id()
327355

328-
signature = self.__make_signature(req, bucket_name, key, additional_headers)
356+
signature = self.__make_signature(req, bucket_name, key, additional_headers, credentials)
329357

330358
req.params['x-oss-signature'] = signature
331359

332360
return req.url + '?' + '&'.join(_param_to_quoted_query(k, v) for k, v in req.params.items())
333361

334-
def __make_signature(self, req, bucket_name, key, additional_headers):
362+
def __make_signature(self, req, bucket_name, key, additional_headers, credentials):
335363
if is_py2:
336364
string_to_sign = self.__get_string_to_sign(req, bucket_name, key, additional_headers)
337365
else:
338366
string_to_sign = self.__get_bytes_to_sign(req, bucket_name, key, additional_headers)
339367

340368
logger.debug('Make signature: string to be signed = {0}'.format(string_to_sign))
341369

342-
h = hmac.new(to_bytes(self.secret), to_bytes(string_to_sign), hashlib.sha256)
370+
h = hmac.new(to_bytes(credentials.get_access_key_secret()), to_bytes(string_to_sign), hashlib.sha256)
343371
return utils.b64encode_as_string(h.digest())
344372

345373
def __get_additional_headers(self, req, in_additional_headers):
@@ -427,7 +455,7 @@ def __get_bytes_to_sign(self, req, bucket_name, key, additional_header_list):
427455
canonicalized_oss_headers +\
428456
additional_headers + b'\n' +\
429457
canonicalized_resource
430-
458+
431459
def __get_canonicalized_oss_headers_bytes(self, req, additional_headers):
432460
"""
433461
:param additional_headers: 小写的headers列表, 并且这些headers都不以'x-oss-'为前缀.
@@ -442,3 +470,13 @@ def __get_canonicalized_oss_headers_bytes(self, req, additional_headers):
442470
canon_headers.sort(key=lambda x: x[0])
443471

444472
return b''.join(to_bytes(v[0]) + b':' + to_bytes(v[1]) + b'\n' for v in canon_headers)
473+
474+
475+
class AuthV2(ProviderAuthV2):
476+
"""签名版本2,与版本1的区别在:
477+
1. 使用SHA256算法,具有更高的安全性
478+
2. 参数计算包含所有的HTTP查询参数
479+
"""
480+
def __init__(self, access_key_id, access_key_secret):
481+
credentials_provider = StaticCredentialsProvider(access_key_id.strip(), access_key_secret.strip())
482+
super(AuthV2, self).__init__(credentials_provider)

oss2/credentials.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# -*- coding: utf-8 -*-
2+
3+
import time
4+
import requests
5+
import json
6+
import logging
7+
import threading
8+
from .exceptions import ClientError
9+
from .utils import to_unixtime
10+
from .compat import to_unicode
11+
12+
logger = logging.getLogger(__name__)
13+
14+
15+
class Credentials(object):
16+
def __init__(self, access_key_id="", access_key_secret="", security_token=""):
17+
self.access_key_id = access_key_id
18+
self.access_key_secret = access_key_secret
19+
self.security_token = security_token
20+
21+
def get_access_key_id(self):
22+
return self.access_key_id
23+
24+
def get_access_key_secret(self):
25+
return self.access_key_secret
26+
27+
def get_security_token(self):
28+
return self.security_token
29+
30+
31+
DEFAULT_ECS_SESSION_TOKEN_DURATION_SECONDS = 3600 * 6
32+
DEFAULT_ECS_SESSION_EXPIRED_FACTOR = 0.85
33+
34+
35+
class EcsRamRoleCredential(Credentials):
36+
def __init__(self,
37+
access_key_id,
38+
access_key_secret,
39+
security_token,
40+
expiration,
41+
duration,
42+
expired_factor=None):
43+
self.access_key_id = access_key_id
44+
self.access_key_secret = access_key_secret
45+
self.security_token = security_token
46+
self.expiration = expiration
47+
self.duration = duration
48+
self.expired_factor = expired_factor or DEFAULT_ECS_SESSION_EXPIRED_FACTOR
49+
50+
def get_access_key_id(self):
51+
return self.access_key_id
52+
53+
def get_access_key_secret(self):
54+
return self.access_key_secret
55+
56+
def get_security_token(self):
57+
return self.security_token
58+
59+
def will_soon_expire(self):
60+
now = int(time.time())
61+
return self.duration * (1.0 - self.expired_factor) > self.expiration - now
62+
63+
64+
class CredentialsProvider(object):
65+
def get_credentials(self):
66+
return
67+
68+
69+
class StaticCredentialsProvider(CredentialsProvider):
70+
def __init__(self, access_key_id="", access_key_secret="", security_token=""):
71+
self.credentials = Credentials(access_key_id, access_key_secret, security_token)
72+
73+
def get_credentials(self):
74+
return self.credentials
75+
76+
77+
class EcsRamRoleCredentialsProvider(CredentialsProvider):
78+
def __init__(self, auth_host, max_retries=3, timeout=10):
79+
self.fetcher = EcsRamRoleCredentialsFetcher(auth_host)
80+
self.max_retries = max_retries
81+
self.timeout = timeout
82+
self.credentials = None
83+
self.__lock = threading.Lock()
84+
85+
def get_credentials(self):
86+
if self.credentials is None or self.credentials.will_soon_expire():
87+
with self.__lock:
88+
if self.credentials is None or self.credentials.will_soon_expire():
89+
try:
90+
self.credentials = self.fetcher.fetch(self.max_retries, self.timeout)
91+
except Exception as e:
92+
logger.error("Exception: {0}".format(e))
93+
if self.credentials is None:
94+
raise
95+
96+
return self.credentials
97+
98+
99+
class EcsRamRoleCredentialsFetcher(object):
100+
def __init__(self, auth_host):
101+
self.auth_host = auth_host
102+
103+
def fetch(self, retry_times=3, timeout=10):
104+
for i in range(0, retry_times):
105+
try:
106+
response = requests.get(self.auth_host, timeout=timeout)
107+
if response.status_code != 200:
108+
raise ClientError(
109+
"Failed to fetch credentials url, http code:{0}, msg:{1}".format(response.status_code,
110+
response.text))
111+
dic = json.loads(to_unicode(response.content))
112+
code = dic.get('Code')
113+
access_key_id = dic.get('AccessKeyId')
114+
access_key_secret = dic.get('AccessKeySecret')
115+
security_token = dic.get('SecurityToken')
116+
expiration_date = dic.get('Expiration')
117+
last_updated_date = dic.get('LastUpdated')
118+
119+
if code != "Success":
120+
raise ClientError("Get credentials from ECS metadata service error, code: {0}".format(code))
121+
122+
expiration_stamp = to_unixtime(expiration_date, "%Y-%m-%dT%H:%M:%SZ")
123+
duration = DEFAULT_ECS_SESSION_TOKEN_DURATION_SECONDS
124+
if last_updated_date is not None:
125+
last_updated_stamp = to_unixtime(last_updated_date, "%Y-%m-%dT%H:%M:%SZ")
126+
duration = expiration_stamp - last_updated_stamp
127+
return EcsRamRoleCredential(access_key_id, access_key_secret, security_token, expiration_stamp,
128+
duration, DEFAULT_ECS_SESSION_EXPIRED_FACTOR)
129+
except Exception as e:
130+
if i == retry_times - 1:
131+
logger.error("Exception: {0}".format(e))
132+
raise ClientError("Failed to get credentials from ECS metadata service. {0}".format(e))

tests/common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
OSS_INVENTORY_BUCKET_DESTINATION_ACCOUNT = os.getenv("OSS_TEST_RAM_UID")
3737

3838
OSS_AUTH_VERSION = None
39+
OSS_TEST_AUTH_SERVER_HOST = os.getenv("OSS_TEST_AUTH_SERVER_HOST")
3940

4041
private_key = RSA.generate(1024)
4142
public_key = private_key.publickey()

0 commit comments

Comments
 (0)