Skip to content

Commit 1dcaa7b

Browse files
authored
feat: Enable notebook instances to get presigned url (#4107)
1 parent 498d94d commit 1dcaa7b

File tree

3 files changed

+241
-147
lines changed

3 files changed

+241
-147
lines changed

src/sagemaker/interactive_apps/base_interactive_app.py

Lines changed: 29 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def __init__(
5959
# Used to store domain and user profile info retrieved from Studio environment.
6060
self._domain_id = None
6161
self._user_profile_name = None
62+
self._in_studio_env = False
6263
self._get_domain_and_user()
6364

6465
def __str__(self):
@@ -70,25 +71,25 @@ def __repr__(self):
7071
return self.__str__()
7172

7273
def _get_domain_and_user(self):
73-
"""Get and validate studio domain id and user profile from studio environment."""
74-
if not self._is_in_studio():
74+
"""Get domain id and user profile from Studio environment.
75+
76+
To verify Studio environment, we check if NOTEBOOK_METADATA_FILE exists
77+
and domain id and user profile name are present in the file.
78+
"""
79+
if not os.path.isfile(NOTEBOOK_METADATA_FILE):
7580
return
7681

7782
try:
7883
with open(NOTEBOOK_METADATA_FILE, "rb") as metadata_file:
7984
metadata = json.loads(metadata_file.read())
80-
if not self._validate_domain_id(
81-
metadata.get("DomainId")
82-
) or not self._validate_user_profile_name(metadata.get("UserProfileName")):
83-
logger.warning(
84-
"NOTEBOOK_METADATA_FILE detected but failed to get valid domain and user"
85-
" from it."
86-
)
87-
return
88-
self._domain_id = metadata.get("DomainId")
89-
self._user_profile_name = metadata.get("UserProfileName")
9085
except OSError as err:
91-
logger.warning("Could not load Studio metadata due to unexpected error. %s", err)
86+
logger.warning("Could not load metadata due to unexpected error. %s", err)
87+
return
88+
89+
if "DomainId" in metadata and "UserProfileName" in metadata:
90+
self._in_studio_env = True
91+
self._domain_id = metadata.get("DomainId")
92+
self._user_profile_name = metadata.get("UserProfileName")
9293

9394
def _get_presigned_url(
9495
self,
@@ -142,10 +143,6 @@ def _get_presigned_url(
142143

143144
return url
144145

145-
def _is_in_studio(self):
146-
"""Check to see if NOTEBOOK_METADATA_FILE exists to verify Studio environment."""
147-
return os.path.isfile(NOTEBOOK_METADATA_FILE)
148-
149146
def _open_url_in_web_browser(self, url: str):
150147
"""Open a URL in the default web browser.
151148
@@ -154,23 +151,6 @@ def _open_url_in_web_browser(self, url: str):
154151
"""
155152
webbrowser.open(url)
156153

157-
def _validate_domain_id(self, domain_id: Optional[str] = None):
158-
"""Validate domain id format.
159-
160-
Args:
161-
domain_id (str): Optional. The domain ID to validate. If one is not supplied,
162-
self._domain_id will be used instead.
163-
Default: ``None``
164-
165-
Returns:
166-
bool: Whether the supplied domain ID is valid.
167-
"""
168-
if domain_id is None:
169-
domain_id = self._domain_id
170-
if domain_id is None or len(domain_id) > 63:
171-
return False
172-
return True
173-
174154
def _validate_job_name(self, job_name: str):
175155
"""Validate training job name format.
176156
@@ -186,30 +166,35 @@ def _validate_job_name(self, job_name: str):
186166
f"Invalid job name. Job name must match regular expression {job_name_regex}"
187167
)
188168

189-
def _validate_user_profile_name(self, user_profile_name: Optional[str] = None):
169+
def _validate_domain_id(self, domain_id: str):
170+
"""Validate domain id format.
171+
172+
Args:
173+
domain_id (str): Required. The domain ID to validate.
174+
175+
Returns:
176+
bool: Whether the supplied domain ID is valid.
177+
"""
178+
if domain_id is None or len(domain_id) > 63:
179+
return False
180+
return True
181+
182+
def _validate_user_profile_name(self, user_profile_name: str):
190183
"""Validate user profile name format.
191184
192185
Args:
193-
user_profile_name (str): Optional. The user profile name to validate. If one is not
194-
supplied, self._user_profile_name will be used instead.
195-
Default: ``None``
186+
user_profile_name (str): Required. The user profile name to validate.
196187
197188
Returns:
198189
bool: Whether the supplied user profile name is valid.
199190
"""
200-
if user_profile_name is None:
201-
user_profile_name = self._user_profile_name
202191
user_profile_name_regex = "^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}"
203192
if user_profile_name is None or not re.fullmatch(
204193
user_profile_name_regex, user_profile_name
205194
):
206195
return False
207196
return True
208197

209-
def _validate_domain_and_user(self):
210-
"""Helper function to consolidate validation calls."""
211-
return self._validate_domain_id() and self._validate_user_profile_name()
212-
213198
@abc.abstractmethod
214199
def get_app_url(self):
215200
"""Abstract method to generate a URL to help access the application in Studio.

src/sagemaker/interactive_apps/tensorboard.py

Lines changed: 27 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -87,36 +87,11 @@ def get_app_url(
8787
if training_job_name is not None:
8888
self._validate_job_name(training_job_name)
8989

90-
if optional_create_presigned_url_kwargs is None:
91-
optional_create_presigned_url_kwargs = {}
92-
93-
if domain_id is not None:
94-
optional_create_presigned_url_kwargs["DomainId"] = domain_id
95-
96-
if user_profile_name is not None:
97-
optional_create_presigned_url_kwargs["UserProfileName"] = user_profile_name
98-
9990
if (
100-
create_presigned_domain_url
101-
and not self._is_in_studio()
102-
and self._validate_domain_id(optional_create_presigned_url_kwargs.get("DomainId"))
103-
and self._validate_user_profile_name(
104-
optional_create_presigned_url_kwargs.get("UserProfileName")
105-
)
91+
self._in_studio_env
92+
and self._validate_domain_id(self._domain_id)
93+
and self._validate_user_profile_name(self._user_profile_name)
10694
):
107-
state_to_encode = None
108-
redirect = "TensorBoard"
109-
110-
if training_job_name is not None:
111-
state_to_encode = (
112-
"/tensorboard/default/data/plugin/sagemaker_data_manager/"
113-
+ f"add_folder_or_job?Redirect=True&Name={training_job_name}"
114-
)
115-
116-
url = self._get_presigned_url(
117-
optional_create_presigned_url_kwargs, redirect, state_to_encode
118-
)
119-
elif self._is_in_studio() and self._validate_domain_and_user():
12095
if domain_id or user_profile_name:
12196
logger.warning(
12297
"Ignoring passed in domain_id and user_profile_name for Studio set values."
@@ -126,15 +101,37 @@ def get_app_url(
126101
+ "sagemaker.aws/tensorboard/default"
127102
)
128103
if training_job_name is not None:
129-
self._validate_job_name(training_job_name)
130104
url += (
131105
"/data/plugin/sagemaker_data_manager/"
132106
+ f"add_folder_or_job?Redirect=True&Name={training_job_name}"
133107
)
134108
else:
135109
url += "/#sagemaker_data_manager"
110+
111+
elif (
112+
not self._in_studio_env
113+
and create_presigned_domain_url
114+
and self._validate_domain_id(domain_id)
115+
and self._validate_user_profile_name(user_profile_name)
116+
):
117+
if optional_create_presigned_url_kwargs is None:
118+
optional_create_presigned_url_kwargs = {}
119+
optional_create_presigned_url_kwargs["DomainId"] = domain_id
120+
optional_create_presigned_url_kwargs["UserProfileName"] = user_profile_name
121+
122+
redirect = "TensorBoard"
123+
state_to_encode = None
124+
if training_job_name is not None:
125+
state_to_encode = (
126+
"/tensorboard/default/data/plugin/sagemaker_data_manager/"
127+
+ f"add_folder_or_job?Redirect=True&Name={training_job_name}"
128+
)
129+
130+
url = self._get_presigned_url(
131+
optional_create_presigned_url_kwargs, redirect, state_to_encode
132+
)
136133
else:
137-
if domain_id or user_profile_name or create_presigned_domain_url:
134+
if not self._in_studio_env and create_presigned_domain_url:
138135
logger.warning(
139136
"A valid domain ID and user profile name were not provided. "
140137
"Providing default landing page URL as a result."

0 commit comments

Comments
 (0)