85
85
SM_CODE_CONTAINER_PATH ,
86
86
SM_DRIVERS ,
87
87
SM_DRIVERS_LOCAL_PATH ,
88
+ SM_RECIPE ,
89
+ SM_RECIPE_YAML ,
90
+ SM_RECIPE_CONTAINER_PATH ,
88
91
TRAIN_SCRIPT ,
89
92
DEFAULT_CONTAINER_ENTRYPOINT ,
90
93
DEFAULT_CONTAINER_ARGUMENTS ,
100
103
from sagemaker .telemetry .telemetry_logging import _telemetry_emitter
101
104
from sagemaker .telemetry .constants import Feature
102
105
from sagemaker .modules import logger
103
- from sagemaker .modules .train .sm_recipes .utils import _get_args_from_recipe , _determine_device_type
106
+ from sagemaker .modules .train .sm_recipes .utils import (
107
+ _get_args_from_recipe ,
108
+ _determine_device_type ,
109
+ _is_nova_recipe ,
110
+ _load_base_recipe ,
111
+ )
104
112
105
113
106
114
class Mode (Enum ):
@@ -242,6 +250,7 @@ class ModelTrainer(BaseModel):
242
250
_remote_debug_config : Optional [RemoteDebugConfig ] = PrivateAttr (default = None )
243
251
_metric_definitions : Optional [List [MetricDefinition ]] = PrivateAttr (default = None )
244
252
253
+ _is_nova_recipe : Optional [bool ] = PrivateAttr (default = None )
245
254
_temp_recipe_train_dir : Optional [TemporaryDirectory ] = PrivateAttr (default = None )
246
255
247
256
CONFIGURABLE_ATTRIBUTES : ClassVar [List [str ]] = [
@@ -449,6 +458,33 @@ def _validate_source_code(self, source_code: Optional[SourceCode]):
449
458
+ "Must be a valid file within the 'source_dir'." ,
450
459
)
451
460
461
@staticmethod
def _validate_and_load_hyperparameters_file(hyperparameters_file: str) -> Dict[str, Any]:
    """Load hyperparameters from a JSON or YAML file and validate them.

    The file is first parsed as JSON; if that fails with a decode error it is
    parsed as YAML with ``yaml.safe_load``. In both cases the parsed document
    must be a mapping, matching the declared return type.

    Args:
        hyperparameters_file (str): Path to the hyperparameters file.

    Returns:
        Dict[str, Any]: The hyperparameters parsed from the file.

    Raises:
        ValueError: If the path does not exist, if the contents are neither
            valid JSON nor valid YAML, or if the parsed document is not a
            mapping.
    """
    if not os.path.exists(hyperparameters_file):
        raise ValueError(f"Hyperparameters file not found: {hyperparameters_file}")
    logger.info(f"Loading hyperparameters from file: {hyperparameters_file}")
    with open(hyperparameters_file, "r") as f:
        contents = f.read()

    # Single message for every "unusable file" failure so callers see the same
    # error regardless of which parser rejected the contents.
    invalid_file_message = (
        f"Invalid hyperparameters file: {hyperparameters_file}. "
        "Must be a valid JSON or YAML file."
    )

    try:
        hyperparameters = json.loads(contents)
    except json.JSONDecodeError:
        # Not JSON; fall back to YAML.
        # Raw contents are logged at DEBUG only: hyperparameter files can be
        # large and may hold values that should not appear in default logs.
        logger.debug(f"contents: {contents}")
        try:
            hyperparameters = yaml.safe_load(contents)
        except yaml.YAMLError as err:
            # Chain the parser error so the root cause stays in the traceback.
            raise ValueError(invalid_file_message) from err
        logger.debug(f"hyperparameters: {hyperparameters}")
        loaded_message = "Hyperparameters loaded as YAML"
    else:
        loaded_message = "Hyperparameters loaded as JSON"

    # Valid JSON/YAML can still be a list, scalar, or None (empty file); only a
    # mapping is usable as hyperparameters, so reject anything else for both
    # formats instead of only for YAML.
    if not isinstance(hyperparameters, dict):
        raise ValueError(invalid_file_message)

    logger.debug(loaded_message)
    return hyperparameters
487
+
452
488
def model_post_init (self , __context : Any ):
453
489
"""Post init method to perform custom validation and set default values."""
454
490
self ._validate_training_image_and_algorithm_name (self .training_image , self .algorithm_name )
@@ -510,27 +546,9 @@ def model_post_init(self, __context: Any):
510
546
)
511
547
512
548
if self .hyperparameters and isinstance (self .hyperparameters , str ):
513
- if not os .path .exists (self .hyperparameters ):
514
- raise ValueError (f"Hyperparameters file not found: { self .hyperparameters } " )
515
- logger .info (f"Loading hyperparameters from file: { self .hyperparameters } " )
516
- with open (self .hyperparameters , "r" ) as f :
517
- contents = f .read ()
518
- try :
519
- self .hyperparameters = json .loads (contents )
520
- logger .debug ("Hyperparameters loaded as JSON" )
521
- except json .JSONDecodeError :
522
- try :
523
- logger .info (f"contents: { contents } " )
524
- self .hyperparameters = yaml .safe_load (contents )
525
- if not isinstance (self .hyperparameters , dict ):
526
- raise ValueError ("YAML contents must be a valid mapping" )
527
- logger .info (f"hyperparameters: { self .hyperparameters } " )
528
- logger .debug ("Hyperparameters loaded as YAML" )
529
- except (yaml .YAMLError , ValueError ):
530
- raise ValueError (
531
- f"Invalid hyperparameters file: { self .hyperparameters } . "
532
- "Must be a valid JSON or YAML file."
533
- )
549
+ self .hyperparameters = self ._validate_and_load_hyperparameters_file (
550
+ self .hyperparameters
551
+ )
534
552
535
553
if self .training_mode == Mode .SAGEMAKER_TRAINING_JOB :
536
554
if self .output_data_config is None :
@@ -613,6 +631,22 @@ def train(
613
631
614
632
final_input_data_config = list (existing_channels .values ()) + new_channels
615
633
634
+ if self ._is_nova_recipe :
635
+ for input_data in final_input_data_config :
636
+ if input_data .channel_name == SM_RECIPE :
637
+ raise ValueError (
638
+ "Cannot use reserved channel name 'recipe' as an input channel name "
639
+ " for Nova Recipe"
640
+ )
641
+ recipe_file_path = os .path .join (self ._temp_recipe_train_dir .name , SM_RECIPE_YAML )
642
+ recipe_channel = self .create_input_data_channel (
643
+ channel_name = SM_RECIPE ,
644
+ data_source = recipe_file_path ,
645
+ key_prefix = input_data_key_prefix ,
646
+ )
647
+ final_input_data_config .append (recipe_channel )
648
+ self .hyperparameters .update ({"sagemaker_recipe_local_path" : SM_RECIPE_CONTAINER_PATH })
649
+
616
650
if final_input_data_config :
617
651
final_input_data_config = self ._get_input_data_config (
618
652
final_input_data_config , input_data_key_prefix
@@ -1005,6 +1039,7 @@ def from_recipe(
1005
1039
checkpoint_config : Optional [shapes .CheckpointConfig ] = None ,
1006
1040
training_input_mode : Optional [str ] = "File" ,
1007
1041
environment : Optional [Dict [str , str ]] = None ,
1042
+ hyperparameters : Optional [Union [Dict [str , Any ], str ]] = {},
1008
1043
tags : Optional [List [Tag ]] = None ,
1009
1044
sagemaker_session : Optional [Session ] = None ,
1010
1045
role : Optional [str ] = None ,
@@ -1101,14 +1136,21 @@ def from_recipe(
1101
1136
"""
1102
1137
if compute .instance_type is None :
1103
1138
raise ValueError (
1104
- "Must set ``instance_type`` in compute_config when using training recipes."
1139
+ "Must set ``instance_type`` in ``compute`` input when using training recipes."
1105
1140
)
1106
1141
device_type = _determine_device_type (compute .instance_type )
1107
- if device_type == "cpu" :
1142
+ recipe = _load_base_recipe (
1143
+ training_recipe = training_recipe , recipe_overrides = recipe_overrides
1144
+ )
1145
+ is_nova = _is_nova_recipe (recipe = recipe )
1146
+
1147
+ if device_type == "cpu" and not is_nova :
1108
1148
raise ValueError (
1109
- "Training recipes are not supported for CPU instances. "
1149
+ "Training recipe is not supported for CPU instances. "
1110
1150
+ "Please provide a GPU or Tranium instance type."
1111
1151
)
1152
+ if training_image is None and is_nova :
1153
+ raise ValueError ("training_image must be provided when using recipe for Nova." )
1112
1154
1113
1155
if training_image_config and training_image is None :
1114
1156
raise ValueError ("training_image must be provided when using training_image_config." )
@@ -1126,15 +1168,27 @@ def from_recipe(
1126
1168
# - distributed
1127
1169
# - compute
1128
1170
# - hyperparameters
1129
- model_trainer_args , recipe_train_dir = _get_args_from_recipe (
1130
- training_recipe = training_recipe ,
1171
+ model_trainer_args , tmp_dir = _get_args_from_recipe (
1172
+ training_recipe = recipe ,
1131
1173
recipe_overrides = recipe_overrides ,
1132
1174
requirements = requirements ,
1133
1175
compute = compute ,
1134
1176
region_name = sagemaker_session .boto_region_name ,
1177
+ role = role ,
1135
1178
)
1136
1179
if training_image is not None :
1137
1180
model_trainer_args ["training_image" ] = training_image
1181
+ if hyperparameters and not is_nova :
1182
+ logger .warning (
1183
+ "Hyperparameters are not supported for general training recipes. "
1184
+ + "Ignoring hyperparameters input."
1185
+ )
1186
+ if is_nova :
1187
+ if hyperparameters and isinstance (hyperparameters , str ):
1188
+ hyperparameters = cls ._validate_and_load_hyperparameters_file (hyperparameters )
1189
+ model_trainer_args ["hyperparameters" ].update (hyperparameters )
1190
+ elif hyperparameters and isinstance (hyperparameters , dict ):
1191
+ model_trainer_args ["hyperparameters" ].update (hyperparameters )
1138
1192
1139
1193
model_trainer = cls (
1140
1194
sagemaker_session = sagemaker_session ,
@@ -1151,8 +1205,8 @@ def from_recipe(
1151
1205
tags = tags ,
1152
1206
** model_trainer_args ,
1153
1207
)
1154
-
1155
- model_trainer ._temp_recipe_train_dir = recipe_train_dir
1208
+ model_trainer . _is_nova_recipe = is_nova
1209
+ model_trainer ._temp_recipe_train_dir = tmp_dir
1156
1210
return model_trainer
1157
1211
1158
1212
def with_tensorboard_output_config (
0 commit comments