@@ -45,6 +45,7 @@ def __init__(self):
4545 self .framework = None
4646 self .framework_version = None
4747 self ._is_mlflow_model = False
48+ self .config_name = None
4849
4950 def _deploy_local_endpoint (self , ** kwargs ):
5051 return Mock ()
@@ -816,6 +817,64 @@ def test_build_unsupported_image_uri(self, mock_init):
816817 self .builder ._build_for_jumpstart ()
817818 self .assertIn ("Unsupported" , str (ctx .exception ))
818819
820+ @patch ('sagemaker.core.jumpstart.factory.utils.get_init_kwargs' )
821+ @patch ('sagemaker.serve.model_builder_servers.prepare_djl_js_resources' )
822+ @patch .object (MockModelBuilderServers , '_prepare_for_mode' )
823+ @patch .object (MockModelBuilderServers , '_create_model' )
824+ def test_build_passes_config_name_to_get_init_kwargs (self , mock_create , mock_prepare_mode , mock_djl_res , mock_init ):
825+ """Test that config_name is forwarded to get_init_kwargs."""
826+ mock_init_kwargs = Mock ()
827+ mock_init_kwargs .image_uri = "djl-inference:0.21.0"
828+ mock_init_kwargs .env = {"TEST" : "value" }
829+ mock_init_kwargs .model_data = "s3://bucket/model.tar.gz"
830+ mock_init .return_value = mock_init_kwargs
831+ mock_djl_res .return_value = ({"config" : "value" }, True )
832+ mock_create .return_value = Mock ()
833+ self .builder .mode = Mode .LOCAL_CONTAINER
834+ self .builder .image_uri = None
835+ self .builder .config_name = "lmi-optimized"
836+
837+ self .builder ._build_for_jumpstart ()
838+
839+ mock_init .assert_called_once_with (
840+ model_id = self .builder .model ,
841+ model_version = "*" ,
842+ region = self .builder .region ,
843+ instance_type = self .builder .instance_type ,
844+ tolerate_vulnerable_model = None ,
845+ tolerate_deprecated_model = None ,
846+ config_name = "lmi-optimized" ,
847+ )
848+
849+ @patch ('sagemaker.core.jumpstart.factory.utils.get_init_kwargs' )
850+ @patch ('sagemaker.serve.model_builder_servers.prepare_djl_js_resources' )
851+ @patch .object (MockModelBuilderServers , '_prepare_for_mode' )
852+ @patch .object (MockModelBuilderServers , '_create_model' )
853+ def test_build_passes_none_config_name_when_not_set (self , mock_create , mock_prepare_mode , mock_djl_res , mock_init ):
854+ """Test that config_name defaults to None when not set."""
855+ mock_init_kwargs = Mock ()
856+ mock_init_kwargs .image_uri = "djl-inference:0.21.0"
857+ mock_init_kwargs .env = {}
858+ mock_init_kwargs .model_data = "s3://bucket/model.tar.gz"
859+ mock_init .return_value = mock_init_kwargs
860+ mock_djl_res .return_value = ({"config" : "value" }, True )
861+ mock_create .return_value = Mock ()
862+ self .builder .mode = Mode .LOCAL_CONTAINER
863+ self .builder .image_uri = None
864+ self .builder .config_name = None
865+
866+ self .builder ._build_for_jumpstart ()
867+
868+ mock_init .assert_called_once_with (
869+ model_id = self .builder .model ,
870+ model_version = "*" ,
871+ region = self .builder .region ,
872+ instance_type = self .builder .instance_type ,
873+ tolerate_vulnerable_model = None ,
874+ tolerate_deprecated_model = None ,
875+ config_name = None ,
876+ )
877+
819878 @patch ('sagemaker.core.jumpstart.factory.utils.get_init_kwargs' )
820879 @patch .object (MockModelBuilderServers , '_prepare_for_mode' )
821880 @patch .object (MockModelBuilderServers , '_build_for_djl_jumpstart' )
0 commit comments