Skip to content

Commit 567c864

Browse files
committed
remove s3 output location requirement from hub class init
1 parent 6945a04 commit 567c864

File tree

4 files changed

+19
-166
lines changed

4 files changed

+19
-166
lines changed

src/sagemaker/jumpstart/hub/hub.py

+9 -46
Original file line number | Diff line number | Diff line change
@@ -16,35 +16,25 @@
1616
from datetime import datetime
1717
import logging
1818
from typing import Optional, Dict, List, Any, Union
19-
from botocore import exceptions
2019

2120
from sagemaker.jumpstart.constants import JUMPSTART_MODEL_HUB_NAME
2221
from sagemaker.jumpstart.enums import JumpStartScriptScope
2322
from sagemaker.session import Session
2423

25-
from sagemaker.jumpstart.constants import (
26-
JUMPSTART_LOGGER,
27-
)
2824
from sagemaker.jumpstart.types import (
2925
HubContentType,
3026
)
3127
from sagemaker.jumpstart.filters import Constant, Operator, BooleanValues
3228
from sagemaker.jumpstart.hub.utils import (
3329
get_hub_model_version,
3430
get_info_from_hub_resource_arn,
35-
create_hub_bucket_if_it_does_not_exist,
36-
generate_default_hub_bucket_name,
37-
create_s3_object_reference_from_uri,
3831
construct_hub_arn_from_name,
3932
)
4033

4134
from sagemaker.jumpstart.notebook_utils import (
4235
list_jumpstart_models,
4336
)
4437

45-
from sagemaker.jumpstart.hub.types import (
46-
S3ObjectLocation,
47-
)
4838
from sagemaker.jumpstart.hub.interfaces import (
4939
DescribeHubResponse,
5040
DescribeHubContentResponse,
@@ -78,41 +68,11 @@ def __init__(
7868
"""
7969
self.hub_name = hub_name
8070
self.region = sagemaker_session.boto_region_name
71+
self.bucket_name = bucket_name
8172
self._sagemaker_session = (
8273
sagemaker_session
8374
or utils.get_default_jumpstart_session_with_user_agent_suffix(is_hub_content=True)
8475
)
85-
self.hub_storage_location = self._generate_hub_storage_location(bucket_name)
86-
87-
def _fetch_hub_bucket_name(self) -> str:
88-
"""Retrieves hub bucket name from Hub config if exists"""
89-
try:
90-
hub_response = self._sagemaker_session.describe_hub(hub_name=self.hub_name)
91-
hub_output_location = hub_response["S3StorageConfig"].get("S3OutputPath")
92-
if hub_output_location:
93-
location = create_s3_object_reference_from_uri(hub_output_location)
94-
return location.bucket
95-
default_bucket_name = generate_default_hub_bucket_name(self._sagemaker_session)
96-
JUMPSTART_LOGGER.warning(
97-
"There is not a Hub bucket associated with %s. Using %s",
98-
self.hub_name,
99-
default_bucket_name,
100-
)
101-
return default_bucket_name
102-
except exceptions.ClientError:
103-
hub_bucket_name = generate_default_hub_bucket_name(self._sagemaker_session)
104-
JUMPSTART_LOGGER.warning(
105-
"There is not a Hub bucket associated with %s. Using %s",
106-
self.hub_name,
107-
hub_bucket_name,
108-
)
109-
return hub_bucket_name
110-
111-
def _generate_hub_storage_location(self, bucket_name: Optional[str] = None) -> None:
112-
"""Generates an ``S3ObjectLocation`` given a Hub name."""
113-
hub_bucket_name = bucket_name or self._fetch_hub_bucket_name()
114-
curr_timestamp = datetime.now().timestamp()
115-
return S3ObjectLocation(bucket=hub_bucket_name, key=f"{self.hub_name}-{curr_timestamp}")
11676

11777
def _get_latest_model_version(self, model_id: str) -> str:
11878
"""Populates the lastest version of a model from specs no matter what is passed.
@@ -132,17 +92,20 @@ def create(
13292
tags: Optional[str] = None,
13393
) -> Dict[str, str]:
13494
"""Creates a hub with the given description"""
135-
136-
create_hub_bucket_if_it_does_not_exist(
137-
self.hub_storage_location.bucket, self._sagemaker_session
138-
)
95+
curr_timestamp = datetime.now().timestamp()
13996

14097
return self._sagemaker_session.create_hub(
14198
hub_name=self.hub_name,
14299
hub_description=description,
143100
hub_display_name=display_name,
144101
hub_search_keywords=search_keywords,
145-
s3_storage_config={"S3OutputPath": self.hub_storage_location.get_uri()},
102+
s3_storage_config={
103+
"S3OutputPath": (
104+
f"s3://{self.bucket_name}/{self.hub_name}-{curr_timestamp}"
105+
if self.bucket_name
106+
else None
107+
)
108+
},
146109
tags=tags,
147110
)
148111

src/sagemaker/jumpstart/hub/utils.py

-57
Original file line number | Diff line number | Diff line change
@@ -15,8 +15,6 @@
1515
from __future__ import absolute_import
1616
import re
1717
from typing import Optional, List, Any
18-
from sagemaker.jumpstart.hub.types import S3ObjectLocation
19-
from sagemaker.s3_utils import parse_s3_url
2018
from sagemaker.session import Session
2119
from sagemaker.utils import aws_partition
2220
from sagemaker.jumpstart.types import HubContentType, HubArnExtractedInfo
@@ -138,61 +136,6 @@ def generate_hub_arn_for_init_kwargs(
138136
return hub_arn
139137

140138

141-
def generate_default_hub_bucket_name(
142-
sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
143-
) -> str:
144-
"""Return the name of the default bucket to use in relevant Amazon SageMaker Hub interactions.
145-
146-
Returns:
147-
str: The name of the default bucket. If the name was not explicitly specified through
148-
the Session or sagemaker_config, the bucket will take the form:
149-
``sagemaker-hubs-{region}-{AWS account ID}``.
150-
"""
151-
152-
region: str = sagemaker_session.boto_region_name
153-
account_id: str = sagemaker_session.account_id()
154-
155-
# TODO: Validate and fast fail
156-
157-
return f"sagemaker-hubs-{region}-{account_id}"
158-
159-
160-
def create_s3_object_reference_from_uri(s3_uri: Optional[str]) -> Optional[S3ObjectLocation]:
161-
"""Utiity to help generate an S3 object reference"""
162-
if not s3_uri:
163-
return None
164-
165-
bucket, key = parse_s3_url(s3_uri)
166-
167-
return S3ObjectLocation(
168-
bucket=bucket,
169-
key=key,
170-
)
171-
172-
173-
def create_hub_bucket_if_it_does_not_exist(
174-
bucket_name: Optional[str] = None,
175-
sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
176-
) -> str:
177-
"""Creates the default SageMaker Hub bucket if it does not exist.
178-
179-
Returns:
180-
str: The name of the default bucket. Takes the form:
181-
``sagemaker-hubs-{region}-{AWS account ID}``.
182-
"""
183-
184-
region: str = sagemaker_session.boto_region_name
185-
if bucket_name is None:
186-
bucket_name: str = generate_default_hub_bucket_name(sagemaker_session)
187-
188-
sagemaker_session._create_s3_bucket_if_it_does_not_exist(
189-
bucket_name=bucket_name,
190-
region=region,
191-
)
192-
193-
return bucket_name
194-
195-
196139
def is_gated_bucket(bucket_name: str) -> bool:
197140
"""Returns true if the bucket name is the JumpStart gated bucket."""
198141
return bucket_name in constants.JUMPSTART_GATED_BUCKET_NAME_SET

tests/unit/sagemaker/jumpstart/hub/test_hub.py

+10 -22
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,6 @@
1616
import pytest
1717
from mock import Mock
1818
from sagemaker.jumpstart.hub.hub import Hub
19-
from sagemaker.jumpstart.hub.types import S3ObjectLocation
2019

2120

2221
REGION = "us-east-1"
@@ -60,48 +59,35 @@ def test_instantiates(sagemaker_session):
6059

6160

6261
@pytest.mark.parametrize(
63-
("hub_name,hub_description,hub_bucket_name,hub_display_name,hub_search_keywords,tags"),
62+
("hub_name,hub_description,,hub_display_name,hub_search_keywords,tags"),
6463
[
65-
pytest.param("MockHub1", "this is my sagemaker hub", None, None, None, None),
64+
pytest.param("MockHub1", "this is my sagemaker hub", None, None, None),
6665
pytest.param(
6766
"MockHub2",
6867
"this is my sagemaker hub two",
69-
None,
7068
"DisplayMockHub2",
7169
["mock", "hub", "123"],
7270
[{"Key": "tag-key-1", "Value": "tag-value-1"}],
7371
),
7472
],
7573
)
76-
@patch("sagemaker.jumpstart.hub.hub.Hub._generate_hub_storage_location")
7774
def test_create_with_no_bucket_name(
78-
mock_generate_hub_storage_location,
7975
sagemaker_session,
8076
hub_name,
8177
hub_description,
82-
hub_bucket_name,
8378
hub_display_name,
8479
hub_search_keywords,
8580
tags,
8681
):
87-
storage_location = S3ObjectLocation(
88-
"sagemaker-hubs-us-east-1-123456789123", f"{hub_name}-{FAKE_TIME.timestamp()}"
89-
)
90-
mock_generate_hub_storage_location.return_value = storage_location
9182
create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"}
9283
sagemaker_session.create_hub = Mock(return_value=create_hub)
93-
sagemaker_session.describe_hub.return_value = {
94-
"S3StorageConfig": {"S3OutputPath": f"s3://{hub_bucket_name}/{storage_location.key}"}
95-
}
9684
hub = Hub(hub_name=hub_name, sagemaker_session=sagemaker_session)
9785
request = {
9886
"hub_name": hub_name,
9987
"hub_description": hub_description,
10088
"hub_display_name": hub_display_name,
10189
"hub_search_keywords": hub_search_keywords,
102-
"s3_storage_config": {
103-
"S3OutputPath": f"s3://sagemaker-hubs-us-east-1-123456789123/{storage_location.key}"
104-
},
90+
"s3_storage_config": {"S3OutputPath": None},
10591
"tags": tags,
10692
}
10793
response = hub.create(
@@ -128,9 +114,9 @@ def test_create_with_no_bucket_name(
128114
),
129115
],
130116
)
131-
@patch("sagemaker.jumpstart.hub.hub.Hub._generate_hub_storage_location")
117+
@patch("sagemaker.jumpstart.hub.hub.datetime")
132118
def test_create_with_bucket_name(
133-
mock_generate_hub_storage_location,
119+
mock_datetime,
134120
sagemaker_session,
135121
hub_name,
136122
hub_description,
@@ -139,8 +125,8 @@ def test_create_with_bucket_name(
139125
hub_search_keywords,
140126
tags,
141127
):
142-
storage_location = S3ObjectLocation(hub_bucket_name, f"{hub_name}-{FAKE_TIME.timestamp()}")
143-
mock_generate_hub_storage_location.return_value = storage_location
128+
mock_datetime.now.return_value = FAKE_TIME
129+
144130
create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"}
145131
sagemaker_session.create_hub = Mock(return_value=create_hub)
146132
hub = Hub(hub_name=hub_name, sagemaker_session=sagemaker_session, bucket_name=hub_bucket_name)
@@ -149,7 +135,9 @@ def test_create_with_bucket_name(
149135
"hub_description": hub_description,
150136
"hub_display_name": hub_display_name,
151137
"hub_search_keywords": hub_search_keywords,
152-
"s3_storage_config": {"S3OutputPath": f"s3://mock-bucket-123/{storage_location.key}"},
138+
"s3_storage_config": {
139+
"S3OutputPath": f"s3://mock-bucket-123/{hub_name}-{FAKE_TIME.timestamp()}"
140+
},
153141
"tags": tags,
154142
}
155143
response = hub.create(

tests/unit/sagemaker/jumpstart/hub/test_utils.py

-41
Original file line number | Diff line number | Diff line change
@@ -173,30 +173,6 @@ def test_generate_hub_arn_for_init_kwargs():
173173
assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn
174174

175175

176-
def test_create_hub_bucket_if_it_does_not_exist_hub_arn():
177-
mock_sagemaker_session = Mock()
178-
mock_sagemaker_session.account_id.return_value = "123456789123"
179-
mock_sagemaker_session.client("sts").get_caller_identity.return_value = {
180-
"Account": "123456789123"
181-
}
182-
hub_arn = "arn:aws:sagemaker:us-west-2:12346789123:hub/my-awesome-hub"
183-
# Mock custom session with custom values
184-
mock_custom_session = Mock()
185-
mock_custom_session.account_id.return_value = "000000000000"
186-
mock_custom_session.boto_region_name = "us-east-2"
187-
mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None
188-
mock_sagemaker_session.boto_region_name = "us-east-1"
189-
190-
bucket_name = "sagemaker-hubs-us-east-1-123456789123"
191-
created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist(
192-
sagemaker_session=mock_sagemaker_session
193-
)
194-
195-
mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once()
196-
assert created_hub_bucket_name == bucket_name
197-
assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn
198-
199-
200176
def test_is_gated_bucket():
201177
assert utils.is_gated_bucket("jumpstart-private-cache-prod-us-west-2") is True
202178

@@ -207,23 +183,6 @@ def test_is_gated_bucket():
207183
assert utils.is_gated_bucket("") is False
208184

209185

210-
def test_create_hub_bucket_if_it_does_not_exist():
211-
mock_sagemaker_session = Mock()
212-
mock_sagemaker_session.account_id.return_value = "123456789123"
213-
mock_sagemaker_session.client("sts").get_caller_identity.return_value = {
214-
"Account": "123456789123"
215-
}
216-
mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None
217-
mock_sagemaker_session.boto_region_name = "us-east-1"
218-
bucket_name = "sagemaker-hubs-us-east-1-123456789123"
219-
created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist(
220-
sagemaker_session=mock_sagemaker_session
221-
)
222-
223-
mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once()
224-
assert created_hub_bucket_name == bucket_name
225-
226-
227186
@patch("sagemaker.session.Session")
228187
def test_get_hub_model_version_success(mock_session):
229188
hub_name = "test_hub"

0 commit comments

Comments
 (0)