Skip to content

Commit 0375f20

Browse files
committed
fix integ test
1 parent 567c864 commit 0375f20

File tree

2 files changed

+37
-11
lines changed

2 files changed

+37
-11
lines changed

src/sagemaker/jumpstart/hub/hub.py

+13-11
Original file line numberDiff line numberDiff line change
@@ -94,20 +94,22 @@ def create(
9494
"""Creates a hub with the given description"""
9595
curr_timestamp = datetime.now().timestamp()
9696

97-
return self._sagemaker_session.create_hub(
98-
hub_name=self.hub_name,
99-
hub_description=description,
100-
hub_display_name=display_name,
101-
hub_search_keywords=search_keywords,
102-
s3_storage_config={
97+
request = {
98+
"hub_name": self.hub_name,
99+
"hub_description": description,
100+
"hub_display_name": display_name,
101+
"hub_search_keywords": search_keywords,
102+
"tags": tags
103+
}
104+
105+
if self.bucket_name:
106+
request["s3_storage_config"] = {
103107
"S3OutputPath": (
104108
f"s3://{self.bucket_name}/{self.hub_name}-{curr_timestamp}"
105-
if self.bucket_name
106-
else None
107109
)
108-
},
109-
tags=tags,
110-
)
110+
}
111+
112+
return self._sagemaker_session.create_hub(**request)
111113

112114
def describe(self, hub_name: Optional[str] = None) -> DescribeHubResponse:
113115
"""Returns descriptive information about the Hub"""

src/test.py

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import json
2+
import boto3
3+
from sagemaker.session import Session
4+
from sagemaker.jumpstart.model import JumpStartModel
5+
from sagemaker.jumpstart.estimator import JumpStartEstimator
6+
from sagemaker.enums import EndpointType
7+
from sagemaker.jumpstart.artifacts.predictors import _retrieve_default_content_type
8+
9+
model_id = "meta-textgeneration-llama-2-7b"
10+
hub_name = "bencrab-test"
11+
12+
model = JumpStartModel(
13+
model_id=model_id,
14+
hub_name=hub_name,
15+
)
16+
17+
# model.deploy(accept_eula=True, endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED)
18+
19+
_retrieve_default_content_type(
20+
model_id=model_id,
21+
model_version="*",
22+
region="us-west-2",
23+
hub_arn=f"arn:aws:sagemaker:us-west-2:802376408542:hub/{hub_name}",
24+
)

0 commit comments

Comments
 (0)