File tree 2 files changed +17
-0
lines changed
2 files changed +17
-0
lines changed Original file line number Diff line number Diff line change 17
17
18
18
import sagemaker
19
19
from sagemaker import ModelMetrics , Model
20
+ from sagemaker import local
21
+ from sagemaker import session
20
22
from sagemaker .config import (
21
23
ENDPOINT_CONFIG_KMS_KEY_ID_PATH ,
22
24
MODEL_VPC_CONFIG_PATH ,
@@ -560,3 +562,16 @@ def delete_model(self):
560
562
raise ValueError ("The SageMaker model must be created before attempting to delete." )
561
563
562
564
self .sagemaker_session .delete_model (self .name )
565
+
566
+ def _init_sagemaker_session_if_does_not_exist (self , instance_type = None ):
567
+ """Set ``self.sagemaker_session`` to ``LocalSession`` or ``Session`` if it's not already.
568
+
569
+ The type of session object is determined by the instance type.
570
+ """
571
+ if self .sagemaker_session :
572
+ return
573
+
574
+ if instance_type in ("local" , "local_gpu" ):
575
+ self .sagemaker_session = local .LocalSession (sagemaker_config = self ._sagemaker_config )
576
+ else :
577
+ self .sagemaker_session = session .Session (sagemaker_config = self ._sagemaker_config )
Original file line number Diff line number Diff line change @@ -645,6 +645,7 @@ def arguments(self) -> RequestType:
645
645
request_dict = self .step_args
646
646
else :
647
647
if isinstance (self .model , PipelineModel ):
648
+ self .model ._init_sagemaker_session_if_does_not_exist ()
648
649
request_dict = self .model .sagemaker_session ._create_model_request (
649
650
name = "" ,
650
651
role = self .model .role ,
@@ -653,6 +654,7 @@ def arguments(self) -> RequestType:
653
654
enable_network_isolation = self .model .enable_network_isolation ,
654
655
)
655
656
else :
657
+ self .model ._init_sagemaker_session_if_does_not_exist ()
656
658
request_dict = self .model .sagemaker_session ._create_model_request (
657
659
name = "" ,
658
660
role = self .model .role ,
You can’t perform that action at this time.
0 commit comments