Skip to content

Commit c5f9c77

Browse files
keshav-chandakKeshav Chandak
authored andcommitted
fix: Added check for the presence of model package group before creating one (#5063)
Co-authored-by: Keshav Chandak <[email protected]>
1 parent 7b21cb1 commit c5f9c77

File tree

2 files changed

+122
-4
lines changed

2 files changed

+122
-4
lines changed

src/sagemaker/session.py

+52-4
Original file line numberDiff line numberDiff line change
@@ -4347,11 +4347,59 @@ def submit(request):
43474347
if model_package_group_name is not None and not model_package_group_name.startswith(
43484348
"arn:"
43494349
):
4350-
_create_resource(
4351-
lambda: self.sagemaker_client.create_model_package_group(
4352-
ModelPackageGroupName=request["ModelPackageGroupName"]
4350+
is_model_package_group_present = False
4351+
try:
4352+
model_package_groups_response = self.search(
4353+
resource="ModelPackageGroup",
4354+
search_expression={
4355+
"Filters": [
4356+
{
4357+
"Name": "ModelPackageGroupName",
4358+
"Value": request["ModelPackageGroupName"],
4359+
"Operator": "Equals",
4360+
}
4361+
],
4362+
},
4363+
)
4364+
if len(model_package_groups_response.get("Results")) > 0:
4365+
is_model_package_group_present = True
4366+
except Exception: # pylint: disable=W0703
4367+
model_package_groups = []
4368+
model_package_groups_response = self.sagemaker_client.list_model_package_groups(
4369+
NameContains=request["ModelPackageGroupName"],
4370+
)
4371+
model_package_groups = (
4372+
model_package_groups
4373+
+ model_package_groups_response["ModelPackageGroupSummaryList"]
4374+
)
4375+
next_token = model_package_groups_response.get("NextToken")
4376+
4377+
while next_token is not None and next_token != "":
4378+
model_package_groups_response = (
4379+
self.sagemaker_client.list_model_package_groups(
4380+
NameContains=request["ModelPackageGroupName"], NextToken=next_token
4381+
)
4382+
)
4383+
model_package_groups = (
4384+
model_package_groups
4385+
+ model_package_groups_response["ModelPackageGroupSummaryList"]
4386+
)
4387+
next_token = model_package_groups_response.get("NextToken")
4388+
4389+
filtered_model_package_group = list(
4390+
filter(
4391+
lambda mpg: mpg.get("ModelPackageGroupName")
4392+
== request["ModelPackageGroupName"],
4393+
model_package_groups,
4394+
)
4395+
)
4396+
is_model_package_group_present = len(filtered_model_package_group) > 0
4397+
if not is_model_package_group_present:
4398+
_create_resource(
4399+
lambda: self.sagemaker_client.create_model_package_group(
4400+
ModelPackageGroupName=request["ModelPackageGroupName"]
4401+
)
43534402
)
4354-
)
43554403
if "SourceUri" in request and request["SourceUri"] is not None:
43564404
# Remove inference spec from request if the
43574405
# given source uri can lead to auto-population of it

tests/unit/test_session.py

+70
Original file line numberDiff line numberDiff line change
@@ -5006,6 +5006,7 @@ def test_create_model_package_with_sagemaker_config_injection(sagemaker_session)
50065006
domain = "COMPUTER_VISION"
50075007
task = "IMAGE_CLASSIFICATION"
50085008
sample_payload_url = "s3://test-bucket/model"
5009+
sagemaker_session.sagemaker_client.search.return_value = {"Results": []}
50095010
sagemaker_session.create_model_package_from_containers(
50105011
containers=containers,
50115012
content_types=content_types,
@@ -5094,6 +5095,8 @@ def test_create_model_package_from_containers_with_source_uri_and_inference_spec
50945095
skip_model_validation = "All"
50955096
source_uri = "dummy-source-uri"
50965097

5098+
sagemaker_session.sagemaker_client.search.return_value = {"Results": []}
5099+
50975100
created_versioned_mp_arn = (
50985101
"arn:aws:sagemaker:us-west-2:123456789123:model-package/unit-test-package-version/1"
50995102
)
@@ -5149,6 +5152,7 @@ def test_create_model_package_from_containers_with_source_uri_for_unversioned_mp
51495152
approval_status = ("Approved",)
51505153
skip_model_validation = "All"
51515154
source_uri = "dummy-source-uri"
5155+
sagemaker_session.sagemaker_client.search.return_value = {"Results": []}
51525156

51535157
with pytest.raises(
51545158
ValueError,
@@ -5221,6 +5225,8 @@ def test_create_model_package_from_containers_with_source_uri_set_to_mp(sagemake
52215225
return_value={"ModelPackageArn": created_versioned_mp_arn}
52225226
)
52235227

5228+
sagemaker_session.sagemaker_client.search.return_value = {"Results": []}
5229+
52245230
sagemaker_session.create_model_package_from_containers(
52255231
model_package_group_name=model_package_group_name,
52265232
containers=containers,
@@ -5443,6 +5449,7 @@ def test_create_model_package_from_containers_without_instance_types(sagemaker_s
54435449
approval_status = ("Approved",)
54445450
description = "description"
54455451
customer_metadata_properties = {"key1": "value1"}
5452+
sagemaker_session.sagemaker_client.search.return_value = {"Results": []}
54465453
sagemaker_session.create_model_package_from_containers(
54475454
containers=containers,
54485455
content_types=content_types,
@@ -5510,6 +5517,7 @@ def test_create_model_package_from_containers_with_one_instance_types(
55105517
approval_status = ("Approved",)
55115518
description = "description"
55125519
customer_metadata_properties = {"key1": "value1"}
5520+
sagemaker_session.sagemaker_client.search.return_value = {"Results": []}
55135521
sagemaker_session.create_model_package_from_containers(
55145522
containers=containers,
55155523
content_types=content_types,
@@ -7183,3 +7191,65 @@ def test_delete_hub_content_reference(sagemaker_session):
71837191
}
71847192

71857193
sagemaker_session.sagemaker_client.delete_hub_content_reference.assert_called_with(**request)
7194+
7195+
7196+
def test_create_model_package_from_containers_to_create_mpg_if_not_present_without_search(
7197+
sagemaker_session,
7198+
):
7199+
sagemaker_session.sagemaker_client.search.side_effect = Exception()
7200+
sagemaker_session.sagemaker_client.search.return_value = {}
7201+
sagemaker_session.sagemaker_client.list_model_package_groups.side_effect = [
7202+
{
7203+
"ModelPackageGroupSummaryList": [{"ModelPackageGroupName": "mock-mpg"}],
7204+
"NextToken": "NextToken",
7205+
},
7206+
{"ModelPackageGroupSummaryList": [{"ModelPackageGroupName": "mock-mpg-test"}]},
7207+
]
7208+
sagemaker_session.create_model_package_from_containers(
7209+
source_uri="mock-source-uri", model_package_group_name="mock-mpg"
7210+
)
7211+
sagemaker_session.sagemaker_client.create_model_package_group.assert_not_called()
7212+
sagemaker_session.create_model_package_from_containers(
7213+
source_uri="mock-source-uri",
7214+
model_package_group_name="arn:aws:sagemaker:us-east-1:215995503607:model-package-group/mock-mpg",
7215+
)
7216+
sagemaker_session.sagemaker_client.create_model_package_group.assert_not_called()
7217+
sagemaker_session.sagemaker_client.list_model_package_groups.side_effect = [
7218+
{"ModelPackageGroupSummaryList": []}
7219+
]
7220+
sagemaker_session.create_model_package_from_containers(
7221+
source_uri="mock-source-uri", model_package_group_name="mock-mpg"
7222+
)
7223+
sagemaker_session.sagemaker_client.create_model_package_group.assert_called_with(
7224+
ModelPackageGroupName="mock-mpg"
7225+
)
7226+
7227+
7228+
def test_create_model_package_from_containers_to_create_mpg_if_not_present(sagemaker_session):
7229+
# with search api
7230+
sagemaker_session.sagemaker_client.search.return_value = {
7231+
"Results": [
7232+
{
7233+
"ModelPackageGroup": {
7234+
"ModelPackageGroupName": "mock-mpg",
7235+
"ModelPackageGroupArn": "arn:aws:sagemaker:us-west-2:123456789012:model-package-group/mock-mpg",
7236+
}
7237+
}
7238+
]
7239+
}
7240+
sagemaker_session.create_model_package_from_containers(
7241+
source_uri="mock-source-uri", model_package_group_name="mock-mpg"
7242+
)
7243+
sagemaker_session.sagemaker_client.create_model_package_group.assert_not_called()
7244+
sagemaker_session.create_model_package_from_containers(
7245+
source_uri="mock-source-uri",
7246+
model_package_group_name="arn:aws:sagemaker:us-east-1:215995503607:model-package-group/mock-mpg",
7247+
)
7248+
sagemaker_session.sagemaker_client.create_model_package_group.assert_not_called()
7249+
sagemaker_session.sagemaker_client.search.return_value = {"Results": []}
7250+
sagemaker_session.create_model_package_from_containers(
7251+
source_uri="mock-source-uri", model_package_group_name="mock-mpg"
7252+
)
7253+
sagemaker_session.sagemaker_client.create_model_package_group.assert_called_with(
7254+
ModelPackageGroupName="mock-mpg"
7255+
)

0 commit comments

Comments
 (0)