Skip to content

Commit 2b2c9d8

Browse files
pintaoz-awspintaoz
authored andcommitted
Fix error when there is no session to call _create_model_request() (aws#5062)
* Fix error when there is no session to call _create_model_request() * Fix codestyle --------- Co-authored-by: pintaoz <[email protected]>
1 parent d7b8c08 commit 2b2c9d8

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

src/sagemaker/pipeline.py

+15
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
import sagemaker
1919
from sagemaker import ModelMetrics, Model
20+
from sagemaker import local
21+
from sagemaker import session
2022
from sagemaker.config import (
2123
ENDPOINT_CONFIG_KMS_KEY_ID_PATH,
2224
MODEL_VPC_CONFIG_PATH,
@@ -560,3 +562,16 @@ def delete_model(self):
560562
raise ValueError("The SageMaker model must be created before attempting to delete.")
561563

562564
self.sagemaker_session.delete_model(self.name)
565+
566+
def _init_sagemaker_session_if_does_not_exist(self, instance_type=None):
567+
"""Set ``self.sagemaker_session`` to ``LocalSession`` or ``Session`` if it's not already.
568+
569+
The type of session object is determined by the instance type.
570+
"""
571+
if self.sagemaker_session:
572+
return
573+
574+
if instance_type in ("local", "local_gpu"):
575+
self.sagemaker_session = local.LocalSession(sagemaker_config=self._sagemaker_config)
576+
else:
577+
self.sagemaker_session = session.Session(sagemaker_config=self._sagemaker_config)

src/sagemaker/workflow/steps.py

+2
Original file line numberDiff line numberDiff line change
@@ -645,6 +645,7 @@ def arguments(self) -> RequestType:
645645
request_dict = self.step_args
646646
else:
647647
if isinstance(self.model, PipelineModel):
648+
self.model._init_sagemaker_session_if_does_not_exist()
648649
request_dict = self.model.sagemaker_session._create_model_request(
649650
name="",
650651
role=self.model.role,
@@ -653,6 +654,7 @@ def arguments(self) -> RequestType:
653654
enable_network_isolation=self.model.enable_network_isolation,
654655
)
655656
else:
657+
self.model._init_sagemaker_session_if_does_not_exist()
656658
request_dict = self.model.sagemaker_session._create_model_request(
657659
name="",
658660
role=self.model.role,

0 commit comments

Comments
 (0)