
Commit 17f5ab5

benieric authored and Tritin Truong committed
Nova training support
* feature: Added Amazon Nova training support for ModelTrainer and Estimator
1 parent 61d043f commit 17f5ab5

File tree

10 files changed: +1985 -169 lines changed


.gitignore

Lines changed: 2 additions & 1 deletion
@@ -37,4 +37,5 @@ src/sagemaker/modules/train/container_drivers/sourcecode.json
 src/sagemaker/modules/train/container_drivers/distributed.json
 tests/data/**/_repack_model.py
 tests/data/experiment/sagemaker-dev-1.0.tar.gz
-src/sagemaker/serve/tmp_workspace
+src/sagemaker/serve/tmp_workspace
+test-examples

src/sagemaker/estimator.py

Lines changed: 43 additions & 4 deletions
@@ -905,6 +905,30 @@ def _json_encode_hyperparameters(hyperparameters: Dict[str, Any]) -> Dict[str, Any]:
         }
         return hyperparameters
 
+    @staticmethod
+    def _nova_encode_hyperparameters(hyperparameters: Dict[str, Any]) -> Dict[str, Any]:
+        """Applies JSON encoding for Nova job hyperparameters, preserving string values.
+
+        For Nova jobs, string values should not be JSON-encoded.
+
+        Args:
+            hyperparameters (dict): Dictionary of hyperparameters.
+
+        Returns:
+            dict: Dictionary with encoded hyperparameters.
+        """
+        current_hyperparameters = hyperparameters
+        if current_hyperparameters is not None:
+            hyperparameters = {}
+            for k, v in current_hyperparameters.items():
+                if is_pipeline_variable(v):
+                    hyperparameters[str(k)] = v.to_string()
+                elif isinstance(v, str):
+                    hyperparameters[str(k)] = v
+                else:
+                    hyperparameters[str(k)] = json.dumps(v)
+        return hyperparameters
+
     def _prepare_for_training(self, job_name=None):
         """Set any values in the estimator that need to be set before training.
 
@@ -938,7 +962,11 @@ def _prepare_for_training(self, job_name=None):
         self.source_dir = updated_paths["source_dir"]
         self.dependencies = updated_paths["dependencies"]
 
-        if self.source_dir or self.entry_point or self.dependencies:
+        if (
+            self.source_dir
+            or self.entry_point
+            or (self.dependencies and len(self.dependencies) > 0)
+        ):
             # validate source dir will raise a ValueError if there is something wrong with
             # the source directory. We are intentionally not handling it because this is a
             # critical error.
@@ -3579,7 +3607,11 @@ def __init__(
            git_config=git_config,
            enable_network_isolation=enable_network_isolation,
        )
-        if not is_pipeline_variable(entry_point) and entry_point.startswith("s3://"):
+        if (
+            not is_pipeline_variable(entry_point)
+            and entry_point is not None
+            and entry_point.startswith("s3://")
+        ):
            raise ValueError(
                "Invalid entry point script: {}. Must be a path to a local file.".format(
                    entry_point
@@ -3599,6 +3631,7 @@ def __init__(
        self.checkpoint_s3_uri = checkpoint_s3_uri
        self.checkpoint_local_path = checkpoint_local_path
        self.enable_sagemaker_metrics = enable_sagemaker_metrics
+        self.is_nova_job = kwargs.get("is_nova_job", False)
 
    def _prepare_for_training(self, job_name=None):
        """Set hyperparameters needed for training. This method will also validate ``source_dir``.
@@ -3713,7 +3746,10 @@ def _model_entry_point(self):
 
    def set_hyperparameters(self, **kwargs):
        """Escapes the dict argument as JSON, updates the private hyperparameter attribute."""
-        self._hyperparameters.update(EstimatorBase._json_encode_hyperparameters(kwargs))
+        if self.is_nova_job:
+            self._hyperparameters.update(EstimatorBase._nova_encode_hyperparameters(kwargs))
+        else:
+            self._hyperparameters.update(EstimatorBase._json_encode_hyperparameters(kwargs))
 
    def hyperparameters(self):
        """Returns the hyperparameters as a dictionary to use for training.
@@ -3724,7 +3760,10 @@ def hyperparameters(self):
        Returns:
            dict[str, str]: The hyperparameters.
        """
-        return EstimatorBase._json_encode_hyperparameters(self._hyperparameters)
+        if self.is_nova_job:
+            return EstimatorBase._nova_encode_hyperparameters(self._hyperparameters)
+        else:
+            return EstimatorBase._json_encode_hyperparameters(self._hyperparameters)
 
    @classmethod
    def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):
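
To make the change concrete: the only behavioral difference between the two encoders is how string values are treated. A minimal standalone sketch (the `is_pipeline_variable` branch is omitted and the hyperparameter values are illustrative):

import json

def json_encode(hp):
    # Standard path: every value is JSON-encoded, so strings gain quotes.
    return {str(k): json.dumps(v) for k, v in hp.items()}

def nova_encode(hp):
    # Nova path: strings pass through unchanged; other types are JSON-encoded.
    return {str(k): v if isinstance(v, str) else json.dumps(v) for k, v in hp.items()}

hp = {"epochs": 3, "base_model": "nova-micro"}
print(json_encode(hp))  # {'epochs': '3', 'base_model': '"nova-micro"'}
print(nova_encode(hp))  # {'epochs': '3', 'base_model': 'nova-micro'}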

src/sagemaker/fw_utils.py

Lines changed: 1 addition & 1 deletion
@@ -1063,7 +1063,7 @@ def validate_torch_distributed_distribution(
        )
 
    # Check entry point type
-    if not entry_point.endswith(".py"):
+    if entry_point is not None and not entry_point.endswith(".py"):
        err_msg += (
            "Unsupported entry point type for the distribution torch_distributed.\n"
            "Only python programs (*.py) are supported."

src/sagemaker/modules/constants.py

Lines changed: 4 additions & 0 deletions
@@ -25,6 +25,10 @@
    os.path.dirname(os.path.abspath(__file__)), "train/container_drivers"
 )
 
+SM_RECIPE = "recipe"
+SM_RECIPE_YAML = "recipe.yaml"
+SM_RECIPE_CONTAINER_PATH = f"/opt/ml/input/data/recipe/{SM_RECIPE_YAML}"
+
 SOURCE_CODE_JSON = "sourcecode.json"
 DISTRIBUTED_JSON = "distributed.json"
 TRAIN_SCRIPT = "sm_train.sh"
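
These constants rely on SageMaker's channel-mount convention: each input channel is mounted in the training container at /opt/ml/input/data/<channel_name>, so a channel named "recipe" carrying recipe.yaml surfaces at SM_RECIPE_CONTAINER_PATH. A quick sanity check of that relationship:

SM_RECIPE = "recipe"
SM_RECIPE_YAML = "recipe.yaml"

# Channels are mounted at /opt/ml/input/data/<channel_name> inside the container.
SM_RECIPE_CONTAINER_PATH = f"/opt/ml/input/data/{SM_RECIPE}/{SM_RECIPE_YAML}"
assert SM_RECIPE_CONTAINER_PATH == "/opt/ml/input/data/recipe/recipe.yaml"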

src/sagemaker/modules/train/model_trainer.py

Lines changed: 83 additions & 29 deletions
@@ -85,6 +85,9 @@
    SM_CODE_CONTAINER_PATH,
    SM_DRIVERS,
    SM_DRIVERS_LOCAL_PATH,
+    SM_RECIPE,
+    SM_RECIPE_YAML,
+    SM_RECIPE_CONTAINER_PATH,
    TRAIN_SCRIPT,
    DEFAULT_CONTAINER_ENTRYPOINT,
    DEFAULT_CONTAINER_ARGUMENTS,
@@ -100,7 +103,12 @@
 from sagemaker.telemetry.telemetry_logging import _telemetry_emitter
 from sagemaker.telemetry.constants import Feature
 from sagemaker.modules import logger
-from sagemaker.modules.train.sm_recipes.utils import _get_args_from_recipe, _determine_device_type
+from sagemaker.modules.train.sm_recipes.utils import (
+    _get_args_from_recipe,
+    _determine_device_type,
+    _is_nova_recipe,
+    _load_base_recipe,
+)
 
 
 class Mode(Enum):
@@ -242,6 +250,7 @@ class ModelTrainer(BaseModel):
    _remote_debug_config: Optional[RemoteDebugConfig] = PrivateAttr(default=None)
    _metric_definitions: Optional[List[MetricDefinition]] = PrivateAttr(default=None)
 
+    _is_nova_recipe: Optional[bool] = PrivateAttr(default=None)
    _temp_recipe_train_dir: Optional[TemporaryDirectory] = PrivateAttr(default=None)
 
    CONFIGURABLE_ATTRIBUTES: ClassVar[List[str]] = [
@@ -449,6 +458,33 @@ def _validate_source_code(self, source_code: Optional[SourceCode]):
                + "Must be a valid file within the 'source_dir'.",
            )
 
+    @staticmethod
+    def _validate_and_load_hyperparameters_file(hyperparameters_file: str) -> Dict[str, Any]:
+        """Validate the hyperparameters file."""
+        if not os.path.exists(hyperparameters_file):
+            raise ValueError(f"Hyperparameters file not found: {hyperparameters_file}")
+        logger.info(f"Loading hyperparameters from file: {hyperparameters_file}")
+        with open(hyperparameters_file, "r") as f:
+            contents = f.read()
+            try:
+                hyperparameters = json.loads(contents)
+                logger.debug("Hyperparameters loaded as JSON")
+                return hyperparameters
+            except json.JSONDecodeError:
+                try:
+                    logger.info(f"contents: {contents}")
+                    hyperparameters = yaml.safe_load(contents)
+                    if not isinstance(hyperparameters, dict):
+                        raise ValueError("YAML contents must be a valid mapping")
+                    logger.info(f"hyperparameters: {hyperparameters}")
+                    logger.debug("Hyperparameters loaded as YAML")
+                    return hyperparameters
+                except (yaml.YAMLError, ValueError):
+                    raise ValueError(
+                        f"Invalid hyperparameters file: {hyperparameters_file}. "
+                        "Must be a valid JSON or YAML file."
+                    )
+
    def model_post_init(self, __context: Any):
        """Post init method to perform custom validation and set default values."""
        self._validate_training_image_and_algorithm_name(self.training_image, self.algorithm_name)
@@ -510,27 +546,9 @@ def model_post_init(self, __context: Any):
            )
 
        if self.hyperparameters and isinstance(self.hyperparameters, str):
-            if not os.path.exists(self.hyperparameters):
-                raise ValueError(f"Hyperparameters file not found: {self.hyperparameters}")
-            logger.info(f"Loading hyperparameters from file: {self.hyperparameters}")
-            with open(self.hyperparameters, "r") as f:
-                contents = f.read()
-                try:
-                    self.hyperparameters = json.loads(contents)
-                    logger.debug("Hyperparameters loaded as JSON")
-                except json.JSONDecodeError:
-                    try:
-                        logger.info(f"contents: {contents}")
-                        self.hyperparameters = yaml.safe_load(contents)
-                        if not isinstance(self.hyperparameters, dict):
-                            raise ValueError("YAML contents must be a valid mapping")
-                        logger.info(f"hyperparameters: {self.hyperparameters}")
-                        logger.debug("Hyperparameters loaded as YAML")
-                    except (yaml.YAMLError, ValueError):
-                        raise ValueError(
-                            f"Invalid hyperparameters file: {self.hyperparameters}. "
-                            "Must be a valid JSON or YAML file."
-                        )
+            self.hyperparameters = self._validate_and_load_hyperparameters_file(
+                self.hyperparameters
+            )
 
        if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB:
            if self.output_data_config is None:
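
With this refactor, both `ModelTrainer(hyperparameters=...)` and `from_recipe(hyperparameters=...)` (below) accept a file path as well as a dict. A standalone sketch of the parse order the helper implements, JSON first with a YAML fallback (file name is hypothetical):

import json
import yaml  # PyYAML

def load_hyperparameters(path: str) -> dict:
    # Mirrors _validate_and_load_hyperparameters_file: try JSON first,
    # fall back to YAML, and require the result to be a mapping.
    with open(path, "r") as f:
        contents = f.read()
    try:
        return json.loads(contents)
    except json.JSONDecodeError:
        loaded = yaml.safe_load(contents)
        if not isinstance(loaded, dict):
            raise ValueError("YAML contents must be a valid mapping")
        return loaded

hp = load_hyperparameters("hyperparameters.yaml")  # e.g. {"epochs": 5, "lr": 1e-4}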
@@ -613,6 +631,22 @@ def train(
 
        final_input_data_config = list(existing_channels.values()) + new_channels
 
+        if self._is_nova_recipe:
+            for input_data in final_input_data_config:
+                if input_data.channel_name == SM_RECIPE:
+                    raise ValueError(
+                        "Cannot use reserved channel name 'recipe' as an input channel name "
+                        " for Nova Recipe"
+                    )
+            recipe_file_path = os.path.join(self._temp_recipe_train_dir.name, SM_RECIPE_YAML)
+            recipe_channel = self.create_input_data_channel(
+                channel_name=SM_RECIPE,
+                data_source=recipe_file_path,
+                key_prefix=input_data_key_prefix,
+            )
+            final_input_data_config.append(recipe_channel)
+            self.hyperparameters.update({"sagemaker_recipe_local_path": SM_RECIPE_CONTAINER_PATH})
+
        if final_input_data_config:
            final_input_data_config = self._get_input_data_config(
                final_input_data_config, input_data_key_prefix
@@ -1005,6 +1039,7 @@ def from_recipe(
        checkpoint_config: Optional[shapes.CheckpointConfig] = None,
        training_input_mode: Optional[str] = "File",
        environment: Optional[Dict[str, str]] = None,
+        hyperparameters: Optional[Union[Dict[str, Any], str]] = {},
        tags: Optional[List[Tag]] = None,
        sagemaker_session: Optional[Session] = None,
        role: Optional[str] = None,
@@ -1101,14 +1136,21 @@
        """
        if compute.instance_type is None:
            raise ValueError(
-                "Must set ``instance_type`` in compute_config when using training recipes."
+                "Must set ``instance_type`` in ``compute`` input when using training recipes."
            )
        device_type = _determine_device_type(compute.instance_type)
-        if device_type == "cpu":
+        recipe = _load_base_recipe(
+            training_recipe=training_recipe, recipe_overrides=recipe_overrides
+        )
+        is_nova = _is_nova_recipe(recipe=recipe)
+
+        if device_type == "cpu" and not is_nova:
            raise ValueError(
-                "Training recipes are not supported for CPU instances. "
+                "Training recipe is not supported for CPU instances. "
                + "Please provide a GPU or Tranium instance type."
            )
+        if training_image is None and is_nova:
+            raise ValueError("training_image must be provided when using recipe for Nova.")
 
        if training_image_config and training_image is None:
            raise ValueError("training_image must be provided when using training_image_config.")
@@ -1126,15 +1168,27 @@
        # - distributed
        # - compute
        # - hyperparameters
-        model_trainer_args, recipe_train_dir = _get_args_from_recipe(
-            training_recipe=training_recipe,
+        model_trainer_args, tmp_dir = _get_args_from_recipe(
+            training_recipe=recipe,
            recipe_overrides=recipe_overrides,
            requirements=requirements,
            compute=compute,
            region_name=sagemaker_session.boto_region_name,
+            role=role,
        )
        if training_image is not None:
            model_trainer_args["training_image"] = training_image
+        if hyperparameters and not is_nova:
+            logger.warning(
+                "Hyperparameters are not supported for general training recipes. "
+                + "Ignoring hyperparameters input."
+            )
+        if is_nova:
+            if hyperparameters and isinstance(hyperparameters, str):
+                hyperparameters = cls._validate_and_load_hyperparameters_file(hyperparameters)
+                model_trainer_args["hyperparameters"].update(hyperparameters)
+            elif hyperparameters and isinstance(hyperparameters, dict):
+                model_trainer_args["hyperparameters"].update(hyperparameters)
 
        model_trainer = cls(
            sagemaker_session=sagemaker_session,
@@ -1151,8 +1205,8 @@
            tags=tags,
            **model_trainer_args,
        )
-
-        model_trainer._temp_recipe_train_dir = recipe_train_dir
+        model_trainer._is_nova_recipe = is_nova
+        model_trainer._temp_recipe_train_dir = tmp_dir
        return model_trainer
 
    def with_tensorboard_output_config(
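
Putting the pieces together, a hedged usage sketch of the new Nova path; the recipe path, image URI, instance type, and hyperparameters file below are illustrative placeholders, not values from this commit:

from sagemaker.modules.configs import Compute
from sagemaker.modules.train import ModelTrainer

trainer = ModelTrainer.from_recipe(
    training_recipe="my-nova-recipe.yaml",       # placeholder recipe path
    training_image="<nova-training-image-uri>",  # required for Nova recipes
    compute=Compute(instance_type="ml.p5.48xlarge"),
    hyperparameters="hyperparameters.yaml",      # dict or JSON/YAML file path
)

# For a Nova recipe, train() uploads recipe.yaml as the reserved "recipe"
# channel and points the "sagemaker_recipe_local_path" hyperparameter at
# /opt/ml/input/data/recipe/recipe.yaml inside the container.
trainer.train()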
