Skip to content

Commit 5e8e894

Browse files
bencrabtreenargokul
authored andcommitted
remove s3 output location requirement from hub class init (aws#5081)
* remove s3 output location requirement from hub class init * fix integ test hub * lint * fix test --------- Co-authored-by: Gokul Anantha Narayanan <[email protected]>
1 parent b872f3e commit 5e8e894

File tree

4 files changed

+25
-173
lines changed

4 files changed

+25
-173
lines changed

src/sagemaker/jumpstart/hub/hub.py

+16-53
Original file line numberDiff line numberDiff line change
@@ -16,35 +16,25 @@
1616
from datetime import datetime
1717
import logging
1818
from typing import Optional, Dict, List, Any, Union
19-
from botocore import exceptions
2019

2120
from sagemaker.jumpstart.constants import JUMPSTART_MODEL_HUB_NAME
2221
from sagemaker.jumpstart.enums import JumpStartScriptScope
2322
from sagemaker.session import Session
2423

25-
from sagemaker.jumpstart.constants import (
26-
JUMPSTART_LOGGER,
27-
)
2824
from sagemaker.jumpstart.types import (
2925
HubContentType,
3026
)
3127
from sagemaker.jumpstart.filters import Constant, Operator, BooleanValues
3228
from sagemaker.jumpstart.hub.utils import (
3329
get_hub_model_version,
3430
get_info_from_hub_resource_arn,
35-
create_hub_bucket_if_it_does_not_exist,
36-
generate_default_hub_bucket_name,
37-
create_s3_object_reference_from_uri,
3831
construct_hub_arn_from_name,
3932
)
4033

4134
from sagemaker.jumpstart.notebook_utils import (
4235
list_jumpstart_models,
4336
)
4437

45-
from sagemaker.jumpstart.hub.types import (
46-
S3ObjectLocation,
47-
)
4838
from sagemaker.jumpstart.hub.interfaces import (
4939
DescribeHubResponse,
5040
DescribeHubContentResponse,
@@ -66,8 +56,8 @@ class Hub:
6656
def __init__(
6757
self,
6858
hub_name: str,
59+
sagemaker_session: Session,
6960
bucket_name: Optional[str] = None,
70-
sagemaker_session: Optional[Session] = None,
7161
) -> None:
7262
"""Instantiates a SageMaker ``Hub``.
7363
@@ -78,41 +68,11 @@ def __init__(
7868
"""
7969
self.hub_name = hub_name
8070
self.region = sagemaker_session.boto_region_name
71+
self.bucket_name = bucket_name
8172
self._sagemaker_session = (
8273
sagemaker_session
8374
or utils.get_default_jumpstart_session_with_user_agent_suffix(is_hub_content=True)
8475
)
85-
self.hub_storage_location = self._generate_hub_storage_location(bucket_name)
86-
87-
def _fetch_hub_bucket_name(self) -> str:
88-
"""Retrieves hub bucket name from Hub config if exists"""
89-
try:
90-
hub_response = self._sagemaker_session.describe_hub(hub_name=self.hub_name)
91-
hub_output_location = hub_response["S3StorageConfig"].get("S3OutputPath")
92-
if hub_output_location:
93-
location = create_s3_object_reference_from_uri(hub_output_location)
94-
return location.bucket
95-
default_bucket_name = generate_default_hub_bucket_name(self._sagemaker_session)
96-
JUMPSTART_LOGGER.warning(
97-
"There is not a Hub bucket associated with %s. Using %s",
98-
self.hub_name,
99-
default_bucket_name,
100-
)
101-
return default_bucket_name
102-
except exceptions.ClientError:
103-
hub_bucket_name = generate_default_hub_bucket_name(self._sagemaker_session)
104-
JUMPSTART_LOGGER.warning(
105-
"There is not a Hub bucket associated with %s. Using %s",
106-
self.hub_name,
107-
hub_bucket_name,
108-
)
109-
return hub_bucket_name
110-
111-
def _generate_hub_storage_location(self, bucket_name: Optional[str] = None) -> None:
112-
"""Generates an ``S3ObjectLocation`` given a Hub name."""
113-
hub_bucket_name = bucket_name or self._fetch_hub_bucket_name()
114-
curr_timestamp = datetime.now().timestamp()
115-
return S3ObjectLocation(bucket=hub_bucket_name, key=f"{self.hub_name}-{curr_timestamp}")
11676

11777
def _get_latest_model_version(self, model_id: str) -> str:
11878
"""Populates the lastest version of a model from specs no matter what is passed.
@@ -132,19 +92,22 @@ def create(
13292
tags: Optional[str] = None,
13393
) -> Dict[str, str]:
13494
"""Creates a hub with the given description"""
95+
curr_timestamp = datetime.now().timestamp()
13596

136-
create_hub_bucket_if_it_does_not_exist(
137-
self.hub_storage_location.bucket, self._sagemaker_session
138-
)
97+
request = {
98+
"hub_name": self.hub_name,
99+
"hub_description": description,
100+
"hub_display_name": display_name,
101+
"hub_search_keywords": search_keywords,
102+
"tags": tags,
103+
}
139104

140-
return self._sagemaker_session.create_hub(
141-
hub_name=self.hub_name,
142-
hub_description=description,
143-
hub_display_name=display_name,
144-
hub_search_keywords=search_keywords,
145-
s3_storage_config={"S3OutputPath": self.hub_storage_location.get_uri()},
146-
tags=tags,
147-
)
105+
if self.bucket_name:
106+
request["s3_storage_config"] = {
107+
"S3OutputPath": (f"s3://{self.bucket_name}/{self.hub_name}-{curr_timestamp}")
108+
}
109+
110+
return self._sagemaker_session.create_hub(**request)
148111

149112
def describe(self, hub_name: Optional[str] = None) -> DescribeHubResponse:
150113
"""Returns descriptive information about the Hub"""

src/sagemaker/jumpstart/hub/utils.py

-57
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
from __future__ import absolute_import
1616
import re
1717
from typing import Optional, List, Any
18-
from sagemaker.jumpstart.hub.types import S3ObjectLocation
19-
from sagemaker.s3_utils import parse_s3_url
2018
from sagemaker.session import Session
2119
from sagemaker.utils import aws_partition
2220
from sagemaker.jumpstart.types import HubContentType, HubArnExtractedInfo
@@ -139,61 +137,6 @@ def generate_hub_arn_for_init_kwargs(
139137
return hub_arn
140138

141139

142-
def generate_default_hub_bucket_name(
143-
sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
144-
) -> str:
145-
"""Return the name of the default bucket to use in relevant Amazon SageMaker Hub interactions.
146-
147-
Returns:
148-
str: The name of the default bucket. If the name was not explicitly specified through
149-
the Session or sagemaker_config, the bucket will take the form:
150-
``sagemaker-hubs-{region}-{AWS account ID}``.
151-
"""
152-
153-
region: str = sagemaker_session.boto_region_name
154-
account_id: str = sagemaker_session.account_id()
155-
156-
# TODO: Validate and fast fail
157-
158-
return f"sagemaker-hubs-{region}-{account_id}"
159-
160-
161-
def create_s3_object_reference_from_uri(s3_uri: Optional[str]) -> Optional[S3ObjectLocation]:
162-
"""Utiity to help generate an S3 object reference"""
163-
if not s3_uri:
164-
return None
165-
166-
bucket, key = parse_s3_url(s3_uri)
167-
168-
return S3ObjectLocation(
169-
bucket=bucket,
170-
key=key,
171-
)
172-
173-
174-
def create_hub_bucket_if_it_does_not_exist(
175-
bucket_name: Optional[str] = None,
176-
sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
177-
) -> str:
178-
"""Creates the default SageMaker Hub bucket if it does not exist.
179-
180-
Returns:
181-
str: The name of the default bucket. Takes the form:
182-
``sagemaker-hubs-{region}-{AWS account ID}``.
183-
"""
184-
185-
region: str = sagemaker_session.boto_region_name
186-
if bucket_name is None:
187-
bucket_name: str = generate_default_hub_bucket_name(sagemaker_session)
188-
189-
sagemaker_session._create_s3_bucket_if_it_does_not_exist(
190-
bucket_name=bucket_name,
191-
region=region,
192-
)
193-
194-
return bucket_name
195-
196-
197140
def is_gated_bucket(bucket_name: str) -> bool:
198141
"""Returns true if the bucket name is the JumpStart gated bucket."""
199142
return bucket_name in constants.JUMPSTART_GATED_BUCKET_NAME_SET

tests/unit/sagemaker/jumpstart/hub/test_hub.py

+9-22
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import pytest
1717
from mock import Mock
1818
from sagemaker.jumpstart.hub.hub import Hub
19-
from sagemaker.jumpstart.hub.types import S3ObjectLocation
2019

2120

2221
REGION = "us-east-1"
@@ -60,48 +59,34 @@ def test_instantiates(sagemaker_session):
6059

6160

6261
@pytest.mark.parametrize(
63-
("hub_name,hub_description,hub_bucket_name,hub_display_name,hub_search_keywords,tags"),
62+
("hub_name,hub_description,,hub_display_name,hub_search_keywords,tags"),
6463
[
65-
pytest.param("MockHub1", "this is my sagemaker hub", None, None, None, None),
64+
pytest.param("MockHub1", "this is my sagemaker hub", None, None, None),
6665
pytest.param(
6766
"MockHub2",
6867
"this is my sagemaker hub two",
69-
None,
7068
"DisplayMockHub2",
7169
["mock", "hub", "123"],
7270
[{"Key": "tag-key-1", "Value": "tag-value-1"}],
7371
),
7472
],
7573
)
76-
@patch("sagemaker.jumpstart.hub.hub.Hub._generate_hub_storage_location")
7774
def test_create_with_no_bucket_name(
78-
mock_generate_hub_storage_location,
7975
sagemaker_session,
8076
hub_name,
8177
hub_description,
82-
hub_bucket_name,
8378
hub_display_name,
8479
hub_search_keywords,
8580
tags,
8681
):
87-
storage_location = S3ObjectLocation(
88-
"sagemaker-hubs-us-east-1-123456789123", f"{hub_name}-{FAKE_TIME.timestamp()}"
89-
)
90-
mock_generate_hub_storage_location.return_value = storage_location
9182
create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"}
9283
sagemaker_session.create_hub = Mock(return_value=create_hub)
93-
sagemaker_session.describe_hub.return_value = {
94-
"S3StorageConfig": {"S3OutputPath": f"s3://{hub_bucket_name}/{storage_location.key}"}
95-
}
9684
hub = Hub(hub_name=hub_name, sagemaker_session=sagemaker_session)
9785
request = {
9886
"hub_name": hub_name,
9987
"hub_description": hub_description,
10088
"hub_display_name": hub_display_name,
10189
"hub_search_keywords": hub_search_keywords,
102-
"s3_storage_config": {
103-
"S3OutputPath": f"s3://sagemaker-hubs-us-east-1-123456789123/{storage_location.key}"
104-
},
10590
"tags": tags,
10691
}
10792
response = hub.create(
@@ -128,9 +113,9 @@ def test_create_with_no_bucket_name(
128113
),
129114
],
130115
)
131-
@patch("sagemaker.jumpstart.hub.hub.Hub._generate_hub_storage_location")
116+
@patch("sagemaker.jumpstart.hub.hub.datetime")
132117
def test_create_with_bucket_name(
133-
mock_generate_hub_storage_location,
118+
mock_datetime,
134119
sagemaker_session,
135120
hub_name,
136121
hub_description,
@@ -139,8 +124,8 @@ def test_create_with_bucket_name(
139124
hub_search_keywords,
140125
tags,
141126
):
142-
storage_location = S3ObjectLocation(hub_bucket_name, f"{hub_name}-{FAKE_TIME.timestamp()}")
143-
mock_generate_hub_storage_location.return_value = storage_location
127+
mock_datetime.now.return_value = FAKE_TIME
128+
144129
create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"}
145130
sagemaker_session.create_hub = Mock(return_value=create_hub)
146131
hub = Hub(hub_name=hub_name, sagemaker_session=sagemaker_session, bucket_name=hub_bucket_name)
@@ -149,7 +134,9 @@ def test_create_with_bucket_name(
149134
"hub_description": hub_description,
150135
"hub_display_name": hub_display_name,
151136
"hub_search_keywords": hub_search_keywords,
152-
"s3_storage_config": {"S3OutputPath": f"s3://mock-bucket-123/{storage_location.key}"},
137+
"s3_storage_config": {
138+
"S3OutputPath": f"s3://mock-bucket-123/{hub_name}-{FAKE_TIME.timestamp()}"
139+
},
153140
"tags": tags,
154141
}
155142
response = hub.create(

tests/unit/sagemaker/jumpstart/hub/test_utils.py

-41
Original file line numberDiff line numberDiff line change
@@ -173,30 +173,6 @@ def test_generate_hub_arn_for_init_kwargs():
173173
assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn
174174

175175

176-
def test_create_hub_bucket_if_it_does_not_exist_hub_arn():
177-
mock_sagemaker_session = Mock()
178-
mock_sagemaker_session.account_id.return_value = "123456789123"
179-
mock_sagemaker_session.client("sts").get_caller_identity.return_value = {
180-
"Account": "123456789123"
181-
}
182-
hub_arn = "arn:aws:sagemaker:us-west-2:12346789123:hub/my-awesome-hub"
183-
# Mock custom session with custom values
184-
mock_custom_session = Mock()
185-
mock_custom_session.account_id.return_value = "000000000000"
186-
mock_custom_session.boto_region_name = "us-east-2"
187-
mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None
188-
mock_sagemaker_session.boto_region_name = "us-east-1"
189-
190-
bucket_name = "sagemaker-hubs-us-east-1-123456789123"
191-
created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist(
192-
sagemaker_session=mock_sagemaker_session
193-
)
194-
195-
mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once()
196-
assert created_hub_bucket_name == bucket_name
197-
assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn
198-
199-
200176
def test_is_gated_bucket():
201177
assert utils.is_gated_bucket("jumpstart-private-cache-prod-us-west-2") is True
202178

@@ -207,23 +183,6 @@ def test_is_gated_bucket():
207183
assert utils.is_gated_bucket("") is False
208184

209185

210-
def test_create_hub_bucket_if_it_does_not_exist():
211-
mock_sagemaker_session = Mock()
212-
mock_sagemaker_session.account_id.return_value = "123456789123"
213-
mock_sagemaker_session.client("sts").get_caller_identity.return_value = {
214-
"Account": "123456789123"
215-
}
216-
mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None
217-
mock_sagemaker_session.boto_region_name = "us-east-1"
218-
bucket_name = "sagemaker-hubs-us-east-1-123456789123"
219-
created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist(
220-
sagemaker_session=mock_sagemaker_session
221-
)
222-
223-
mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once()
224-
assert created_hub_bucket_name == bucket_name
225-
226-
227186
@patch("sagemaker.session.Session")
228187
def test_get_hub_model_version_success(mock_session):
229188
hub_name = "test_hub"

0 commit comments

Comments
 (0)