Skip to content

Commit 7732ecf

Browse files
authored
fix: resolve alt config resolution for jumpstart models (#5563)
Failing mlops test is unrelated.
1 parent f5f636f commit 7732ecf

File tree

3 files changed

+94
-6
lines changed

3 files changed

+94
-6
lines changed

sagemaker-serve/src/sagemaker/serve/model_builder_servers.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -851,11 +851,12 @@ def _build_for_jumpstart(self) -> Model:
851851
# Get JumpStart model configuration
852852
init_kwargs = get_init_kwargs(
853853
model_id=self.model,
854-
model_version=self.model_version or "*",
854+
model_version=self.model_version or "*",
855855
region=self.region,
856856
instance_type=self.instance_type,
857857
tolerate_vulnerable_model=getattr(self, 'tolerate_vulnerable_model', None),
858-
tolerate_deprecated_model=getattr(self, 'tolerate_deprecated_model', None)
858+
tolerate_deprecated_model=getattr(self, 'tolerate_deprecated_model', None),
859+
config_name=getattr(self, 'config_name', None),
859860
)
860861

861862
# Configure image URI and environment variables

sagemaker-serve/tests/unit/servers/test_model_builder_servers.py

Lines changed: 59 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,7 @@ def __init__(self):
4545
self.framework = None
4646
self.framework_version = None
4747
self._is_mlflow_model = False
48+
self.config_name = None
4849

4950
def _deploy_local_endpoint(self, **kwargs):
5051
return Mock()
@@ -816,6 +817,64 @@ def test_build_unsupported_image_uri(self, mock_init):
816817
self.builder._build_for_jumpstart()
817818
self.assertIn("Unsupported", str(ctx.exception))
818819

820+
@patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs')
821+
@patch('sagemaker.serve.model_builder_servers.prepare_djl_js_resources')
822+
@patch.object(MockModelBuilderServers, '_prepare_for_mode')
823+
@patch.object(MockModelBuilderServers, '_create_model')
824+
def test_build_passes_config_name_to_get_init_kwargs(self, mock_create, mock_prepare_mode, mock_djl_res, mock_init):
825+
"""Test that config_name is forwarded to get_init_kwargs."""
826+
mock_init_kwargs = Mock()
827+
mock_init_kwargs.image_uri = "djl-inference:0.21.0"
828+
mock_init_kwargs.env = {"TEST": "value"}
829+
mock_init_kwargs.model_data = "s3://bucket/model.tar.gz"
830+
mock_init.return_value = mock_init_kwargs
831+
mock_djl_res.return_value = ({"config": "value"}, True)
832+
mock_create.return_value = Mock()
833+
self.builder.mode = Mode.LOCAL_CONTAINER
834+
self.builder.image_uri = None
835+
self.builder.config_name = "lmi-optimized"
836+
837+
self.builder._build_for_jumpstart()
838+
839+
mock_init.assert_called_once_with(
840+
model_id=self.builder.model,
841+
model_version="*",
842+
region=self.builder.region,
843+
instance_type=self.builder.instance_type,
844+
tolerate_vulnerable_model=None,
845+
tolerate_deprecated_model=None,
846+
config_name="lmi-optimized",
847+
)
848+
849+
@patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs')
850+
@patch('sagemaker.serve.model_builder_servers.prepare_djl_js_resources')
851+
@patch.object(MockModelBuilderServers, '_prepare_for_mode')
852+
@patch.object(MockModelBuilderServers, '_create_model')
853+
def test_build_passes_none_config_name_when_not_set(self, mock_create, mock_prepare_mode, mock_djl_res, mock_init):
854+
"""Test that config_name defaults to None when not set."""
855+
mock_init_kwargs = Mock()
856+
mock_init_kwargs.image_uri = "djl-inference:0.21.0"
857+
mock_init_kwargs.env = {}
858+
mock_init_kwargs.model_data = "s3://bucket/model.tar.gz"
859+
mock_init.return_value = mock_init_kwargs
860+
mock_djl_res.return_value = ({"config": "value"}, True)
861+
mock_create.return_value = Mock()
862+
self.builder.mode = Mode.LOCAL_CONTAINER
863+
self.builder.image_uri = None
864+
self.builder.config_name = None
865+
866+
self.builder._build_for_jumpstart()
867+
868+
mock_init.assert_called_once_with(
869+
model_id=self.builder.model,
870+
model_version="*",
871+
region=self.builder.region,
872+
instance_type=self.builder.instance_type,
873+
tolerate_vulnerable_model=None,
874+
tolerate_deprecated_model=None,
875+
config_name=None,
876+
)
877+
819878
@patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs')
820879
@patch.object(MockModelBuilderServers, '_prepare_for_mode')
821880
@patch.object(MockModelBuilderServers, '_build_for_djl_jumpstart')

sagemaker-serve/tests/unit/test_model_builder_servers_coverage.py

Lines changed: 32 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -347,24 +347,52 @@ def test_build_for_jumpstart_routes_to_djl(self, mock_prepare, mock_build_djl, m
347347
mock_init_kwargs.image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.0-cu117"
348348
mock_init_kwargs.env = {}
349349
mock_get_kwargs.return_value = mock_init_kwargs
350-
350+
351351
mock_model = Mock(spec=Model)
352352
mock_build_djl.return_value = mock_model
353-
353+
354354
builder = ModelBuilder(
355355
model="huggingface-llm-falcon-7b",
356356
role_arn=MOCK_ROLE_ARN,
357357
sagemaker_session=self.mock_session,
358358
mode=Mode.SAGEMAKER_ENDPOINT
359359
)
360360
builder._optimizing = False
361-
361+
362362
result = builder._build_for_jumpstart()
363-
363+
364364
self.assertEqual(result, mock_model)
365365
self.assertEqual(builder.model_server, ModelServer.DJL_SERVING)
366366
mock_build_djl.assert_called_once()
367367

368+
@patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs')
369+
@patch('sagemaker.serve.model_builder.ModelBuilder._build_for_djl_jumpstart')
370+
@patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode')
371+
def test_build_for_jumpstart_passes_config_name(self, mock_prepare, mock_build_djl, mock_get_kwargs):
372+
"""Test that config_name is forwarded to get_init_kwargs."""
373+
mock_init_kwargs = Mock()
374+
mock_init_kwargs.image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.0-cu117"
375+
mock_init_kwargs.env = {}
376+
mock_get_kwargs.return_value = mock_init_kwargs
377+
378+
mock_model = Mock(spec=Model)
379+
mock_build_djl.return_value = mock_model
380+
381+
builder = ModelBuilder(
382+
model="meta-textgeneration-llama-3-3-70b-instruct",
383+
role_arn=MOCK_ROLE_ARN,
384+
sagemaker_session=self.mock_session,
385+
mode=Mode.SAGEMAKER_ENDPOINT
386+
)
387+
builder._optimizing = False
388+
builder.config_name = "lmi-optimized"
389+
390+
builder._build_for_jumpstart()
391+
392+
mock_get_kwargs.assert_called_once()
393+
call_kwargs = mock_get_kwargs.call_args
394+
self.assertEqual(call_kwargs.kwargs.get("config_name") or call_kwargs[1].get("config_name"), "lmi-optimized")
395+
368396
@patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs')
369397
@patch('sagemaker.serve.model_builder.ModelBuilder._build_for_tgi_jumpstart')
370398
@patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode')

0 commit comments

Comments
 (0)