@@ -186,6 +186,7 @@ def __init__(
186
186
enable_remote_debug : Optional [Union [bool , PipelineVariable ]] = None ,
187
187
enable_session_tag_chaining : Optional [Union [bool , PipelineVariable ]] = None ,
188
188
training_plan : Optional [Union [str , PipelineVariable ]] = None ,
189
+ instance_placement_config : Optional [Dict ] = None ,
189
190
** kwargs ,
190
191
):
191
192
"""Initialize an ``EstimatorBase`` instance.
@@ -560,6 +561,21 @@ def __init__(
560
561
Specifies whether SessionTagChaining is enabled for the training job.
561
562
training_plan (str or PipelineVariable): Optional.
562
563
Specifies which training plan arn to use for the training job
564
+ instance_placement_config (dict): Optional.
565
+ Specifies UltraServer placement configuration for the training job
566
+
567
+ .. code:: python
568
+
569
+ instance_placement_config={
570
+ "EnableMultipleJobs": True,
571
+ "PlacementSpecifications":[
572
+ {
573
+ "UltraServerId": "ultraserver-1",
574
+ "InstanceCount": "2"
575
+ }
576
+ ]
577
+ }
578
+
563
579
"""
564
580
instance_count = renamed_kwargs (
565
581
"train_instance_count" , "instance_count" , instance_count , kwargs
@@ -813,6 +829,8 @@ def __init__(
813
829
814
830
self .training_plan = training_plan
815
831
832
+ self .instance_placement_config = instance_placement_config
833
+
816
834
# Internal flag
817
835
self ._is_output_path_set_from_default_bucket_and_prefix = False
818
836
@@ -1997,6 +2015,11 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
1997
2015
if "TrainingPlanArn" in job_details ["ResourceConfig" ]:
1998
2016
init_params ["training_plan" ] = job_details ["ResourceConfig" ]["TrainingPlanArn" ]
1999
2017
2018
+ if "InstancePlacementConfig" in job_details ["ResourceConfig" ]:
2019
+ init_params ["instance_placement_config" ] = job_details ["ResourceConfig" ][
2020
+ "InstancePlacementConfig"
2021
+ ]
2022
+
2000
2023
has_hps = "HyperParameters" in job_details
2001
2024
init_params ["hyperparameters" ] = job_details ["HyperParameters" ] if has_hps else {}
2002
2025
@@ -2882,6 +2905,7 @@ def __init__(
2882
2905
enable_remote_debug : Optional [Union [bool , PipelineVariable ]] = None ,
2883
2906
enable_session_tag_chaining : Optional [Union [bool , PipelineVariable ]] = None ,
2884
2907
training_plan : Optional [Union [str , PipelineVariable ]] = None ,
2908
+ instance_placement_config : Optional [Dict ] = None ,
2885
2909
** kwargs ,
2886
2910
):
2887
2911
"""Initialize an ``Estimator`` instance.
@@ -3249,6 +3273,20 @@ def __init__(
3249
3273
Specifies whether SessionTagChaining is enabled for the training job
3250
3274
training_plan (str or PipelineVariable): Optional.
3251
3275
Specifies which training plan arn to use for the training job
3276
+ instance_placement_config (dict): Optional.
3277
+ Specifies UltraServer placement configuration for the training job
3278
+
3279
+ .. code:: python
3280
+
3281
+ instance_placement_config={
3282
+ "EnableMultipleJobs": True,
3283
+ "PlacementSpecifications":[
3284
+ {
3285
+ "UltraServerId": "ultraserver-1",
3286
+ "InstanceCount": "2"
3287
+ }
3288
+ ]
3289
+ }
3252
3290
"""
3253
3291
self .image_uri = image_uri
3254
3292
self ._hyperparameters = hyperparameters .copy () if hyperparameters else {}
@@ -3303,6 +3341,7 @@ def __init__(
3303
3341
enable_remote_debug = enable_remote_debug ,
3304
3342
enable_session_tag_chaining = enable_session_tag_chaining ,
3305
3343
training_plan = training_plan ,
3344
+ instance_placement_config = instance_placement_config ,
3306
3345
** kwargs ,
3307
3346
)
3308
3347
0 commit comments