Skip to content

Commit 6df69e3

Browse files
malav-shastriMalav Shastri
and
Malav Shastri
authored
fix: alt configs model deployment and training issues (#4833)
* fix: alt configs training * pass hub_arn to verify_model_region_and_return_specs * fix training job with alt configs and telemetry changes * fix linting issues * address comments * address Pylint * address linter issues * Add unit tests for walk_and_apply_json changes --------- Co-authored-by: Malav Shastri <[email protected]>
1 parent 686d6f0 commit 6df69e3

File tree

6 files changed

+102
-23
lines changed

6 files changed

+102
-23
lines changed

src/sagemaker/jumpstart/artifacts/metric_definitions.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -96,16 +96,17 @@ def _retrieve_default_training_metric_definitions(
9696
else []
9797
)
9898

99-
instance_specific_metric_name: str
100-
for instance_specific_metric_definition in instance_specific_metric_definitions:
101-
instance_specific_metric_name = instance_specific_metric_definition["Name"]
102-
default_metric_definitions = list(
103-
filter(
104-
lambda metric_definition: metric_definition["Name"]
105-
!= instance_specific_metric_name,
106-
default_metric_definitions,
99+
if instance_specific_metric_definitions:
100+
instance_specific_metric_name: str
101+
for instance_specific_metric_definition in instance_specific_metric_definitions:
102+
instance_specific_metric_name = instance_specific_metric_definition["Name"]
103+
default_metric_definitions = list(
104+
filter(
105+
lambda metric_definition: metric_definition["Name"]
106+
!= instance_specific_metric_name,
107+
default_metric_definitions,
108+
)
107109
)
108-
)
109-
default_metric_definitions.append(instance_specific_metric_definition)
110+
default_metric_definitions.append(instance_specific_metric_definition)
110111

111112
return default_metric_definitions

src/sagemaker/jumpstart/factory/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,7 @@ def _add_instance_type_to_kwargs(
259259
sagemaker_session=kwargs.sagemaker_session,
260260
model_type=kwargs.model_type,
261261
config_name=kwargs.config_name,
262+
hub_arn=kwargs.hub_arn,
262263
)
263264

264265
if specs.inference_configs and kwargs.config_name not in specs.inference_configs.configs:
@@ -780,6 +781,7 @@ def _add_config_name_to_deploy_kwargs(
780781
sagemaker_session=temp_session,
781782
model_type=kwargs.model_type,
782783
config_name=kwargs.config_name,
784+
hub_arn=kwargs.hub_arn,
783785
)
784786
default_config_name = _select_inference_config_from_training_config(
785787
specs=specs, training_config_name=training_config_name

src/sagemaker/jumpstart/hub/parser_utils.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,20 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13+
# pylint: skip-file
1314
"""This module contains utilities related to SageMaker JumpStart Hub."""
1415
from __future__ import absolute_import
1516

1617
import re
17-
from typing import Any, Dict
18+
from typing import Any, Dict, List, Optional
1819

1920

2021
def camel_to_snake(camel_case_string: str) -> str:
2122
"""Converts camelCaseString or UpperCamelCaseString to snake_case_string."""
2223
snake_case_string = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel_case_string)
24+
if "-" in snake_case_string:
25+
# remove any hyphen from the string for accurate conversion.
26+
snake_case_string = snake_case_string.replace("-", "")
2327
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake_case_string).lower()
2428

2529

@@ -29,20 +33,29 @@ def snake_to_upper_camel(snake_case_string: str) -> str:
2933
return upper_camel_case_string
3034

3135

32-
def walk_and_apply_json(json_obj: Dict[Any, Any], apply) -> Dict[Any, Any]:
33-
"""Recursively walks a json object and applies a given function to the keys."""
36+
def walk_and_apply_json(
37+
json_obj: Dict[Any, Any], apply, stop_keys: Optional[List[str]] = ["metrics"]
38+
) -> Dict[Any, Any]:
39+
"""Recursively walks a json object and applies a given function to the keys.
40+
41+
stop_keys (Optional[list[str]]): List of field keys that should stop the application function.
42+
Any children of these keys will not have the application function applied to them.
43+
"""
3444

3545
def _walk_and_apply_json(json_obj, new):
3646
if isinstance(json_obj, dict) and isinstance(new, dict):
3747
for key, value in json_obj.items():
3848
new_key = apply(key)
39-
if isinstance(value, dict):
40-
new[new_key] = {}
41-
_walk_and_apply_json(value, new=new[new_key])
42-
elif isinstance(value, list):
43-
new[new_key] = []
44-
for item in value:
45-
_walk_and_apply_json(item, new=new[new_key])
49+
if (stop_keys and new_key not in stop_keys) or stop_keys is None:
50+
if isinstance(value, dict):
51+
new[new_key] = {}
52+
_walk_and_apply_json(value, new=new[new_key])
53+
elif isinstance(value, list):
54+
new[new_key] = []
55+
for item in value:
56+
_walk_and_apply_json(item, new=new[new_key])
57+
else:
58+
new[new_key] = value
4659
else:
4760
new[new_key] = value
4861
elif isinstance(json_obj, dict) and isinstance(new, list):
@@ -51,6 +64,8 @@ def _walk_and_apply_json(json_obj, new):
5164
new.update(json_obj)
5265
elif isinstance(json_obj, list) and isinstance(new, list):
5366
new.append(json_obj)
67+
elif isinstance(json_obj, str) and isinstance(new, list):
68+
new.append(json_obj)
5469
return new
5570

5671
return _walk_and_apply_json(json_obj, new={})

src/sagemaker/jumpstart/types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1174,7 +1174,7 @@ def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content=False):
11741174
spec (Dict[str, Any]): Dictionary representation of training config ranking.
11751175
"""
11761176
if is_hub_content:
1177-
spec = {camel_to_snake(key): val for key, val in spec.items()}
1177+
spec = walk_and_apply_json(spec, camel_to_snake)
11781178
self.from_json(spec)
11791179

11801180
def from_json(self, json_obj: Dict[str, Any]) -> None:
@@ -1400,7 +1400,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
14001400

14011401
if self.training_supported:
14021402
if self._is_hub_content:
1403-
self.training_ecr_uri: Optional[str] = json_obj["training_ecr_uri"]
1403+
self.training_ecr_uri: Optional[str] = json_obj.get("training_ecr_uri")
14041404
self._non_serializable_slots.append("training_ecr_specs")
14051405
else:
14061406
self.training_ecr_specs: Optional[JumpStartECRSpecs] = (

src/sagemaker/jumpstart/utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
from sagemaker.jumpstart import constants, enums
3535
from sagemaker.jumpstart import accessors
36+
from sagemaker.jumpstart.hub.parser_utils import camel_to_snake, snake_to_upper_camel
3637
from sagemaker.s3 import parse_s3_url
3738
from sagemaker.jumpstart.exceptions import (
3839
DeprecatedJumpStartModelError,
@@ -1103,6 +1104,17 @@ def get_jumpstart_configs(
11031104
metadata_configs.config_rankings.get("overall").rankings if metadata_configs else []
11041105
)
11051106

1107+
if hub_arn:
1108+
return (
1109+
{
1110+
config_name: metadata_configs.configs[
1111+
camel_to_snake(snake_to_upper_camel(config_name))
1112+
]
1113+
for config_name in config_names
1114+
}
1115+
if metadata_configs
1116+
else {}
1117+
)
11061118
return (
11071119
{config_name: metadata_configs.configs[config_name] for config_name in config_names}
11081120
if metadata_configs

tests/unit/sagemaker/jumpstart/hub/test_utils.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from unittest.mock import patch, Mock
1616
from sagemaker.jumpstart.types import HubArnExtractedInfo
1717
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
18-
from sagemaker.jumpstart.hub import utils
18+
from sagemaker.jumpstart.hub import parser_utils, utils
1919

2020

2121
def test_get_info_from_hub_resource_arn():
@@ -254,3 +254,52 @@ def test_get_hub_model_version_wildcard_char(mock_session):
254254
)
255255

256256
assert result == "2.0.0"
257+
258+
259+
def test_walk_and_apply_json():
260+
test_json = {
261+
"CamelCaseKey": "value",
262+
"CamelCaseObjectKey": {
263+
"CamelCaseObjectChildOne": "value1",
264+
"CamelCaseObjectChildTwo": "value2",
265+
},
266+
"IgnoreMyChildren": {"ShouldNotBeTouchedOne": "const1", "ShouldNotBeTouchedTwo": "const2"},
267+
"ShouldNotIgnoreMyChildren": {"NopeNope": "no"},
268+
}
269+
270+
result = parser_utils.walk_and_apply_json(
271+
test_json, parser_utils.camel_to_snake, ["ignore_my_children"]
272+
)
273+
assert result == {
274+
"camel_case_key": "value",
275+
"camel_case_object_key": {
276+
"camel_case_object_child_one": "value1",
277+
"camel_case_object_child_two": "value2",
278+
},
279+
"ignore_my_children": {
280+
"ShouldNotBeTouchedOne": "const1",
281+
"ShouldNotBeTouchedTwo": "const2",
282+
},
283+
"should_not_ignore_my_children": {"nope_nope": "no"},
284+
}
285+
286+
287+
def test_walk_and_apply_json_no_stop():
288+
test_json = {
289+
"CamelCaseKey": "value",
290+
"CamelCaseObjectKey": {
291+
"CamelCaseObjectChildOne": "value1",
292+
"CamelCaseObjectChildTwo": "value2",
293+
},
294+
"CamelCaseObjectListKey": {"instance.ml.type.xlarge": [{"ShouldChangeMe": "string"}]},
295+
}
296+
297+
result = parser_utils.walk_and_apply_json(test_json, parser_utils.camel_to_snake)
298+
assert result == {
299+
"camel_case_key": "value",
300+
"camel_case_object_key": {
301+
"camel_case_object_child_one": "value1",
302+
"camel_case_object_child_two": "value2",
303+
},
304+
"camel_case_object_list_key": {"instance.ml.type.xlarge": [{"should_change_me": "string"}]},
305+
}

0 commit comments

Comments
 (0)