32
32
REGION = 'us-west-2'
33
33
BUCKET_NAME = 'mybucket'
34
34
EXPANDED_ROLE = 'arn:aws:iam::111111111111:role/ExpandedRole'
35
+ TRAINING_JOB_NAME = 'my-job'
35
36
INPUT_DATA_CONFIG = [
36
37
{
37
38
'ChannelName' : 'a' ,
55
56
]
56
57
HYPERPARAMETERS = {'a' : 1 ,
57
58
'b' : json .dumps ('bee' ),
58
- 'sagemaker_submit_directory' : json .dumps ('s3://my_bucket/code' ),
59
- 'sagemaker_job_name' : json . dumps ( 'my-job' )}
59
+ 'sagemaker_submit_directory' : json .dumps ('s3://my_bucket/code' )}
60
+
60
61
61
62
LOCAL_CODE_HYPERPARAMETERS = {'a' : 1 ,
62
63
'b' : 2 ,
63
- 'sagemaker_submit_directory' : json .dumps ('file:///tmp/code' ),
64
- 'sagemaker_job_name' : json .dumps ('my-job' )}
64
+ 'sagemaker_submit_directory' : json .dumps ('file:///tmp/code' )}
65
65
66
66
67
67
@pytest .fixture ()
@@ -230,7 +230,7 @@ def test_train(_download_folder, _cleanup, popen, _stream_output, LocalSession,
230
230
instance_count = 2
231
231
image = 'my-image'
232
232
sagemaker_container = _SageMakerContainer ('local' , instance_count , image , sagemaker_session = sagemaker_session )
233
- sagemaker_container .train (INPUT_DATA_CONFIG , HYPERPARAMETERS )
233
+ sagemaker_container .train (INPUT_DATA_CONFIG , HYPERPARAMETERS , TRAINING_JOB_NAME )
234
234
235
235
channel_dir = os .path .join (directories [1 ], 'b' )
236
236
download_folder_calls = [call ('my-own-bucket' , 'prefix' , channel_dir )]
@@ -252,13 +252,36 @@ def test_train(_download_folder, _cleanup, popen, _stream_output, LocalSession,
252
252
assert config ['services' ][h ]['image' ] == image
253
253
assert config ['services' ][h ]['command' ] == 'train'
254
254
assert 'AWS_REGION={}' .format (REGION ) in config ['services' ][h ]['environment' ]
255
- assert 'TRAINING_JOB_NAME=my-job' in config ['services' ][h ]['environment' ]
255
+ assert 'TRAINING_JOB_NAME={}' . format ( TRAINING_JOB_NAME ) in config ['services' ][h ]['environment' ]
256
256
257
257
# assert that expected by sagemaker container output directories exist
258
258
assert os .path .exists (os .path .join (sagemaker_container .container_root , 'output' ))
259
259
assert os .path .exists (os .path .join (sagemaker_container .container_root , 'output/data' ))
260
260
261
261
262
+ @patch ('sagemaker.local.local_session.LocalSession' )
263
+ @patch ('sagemaker.local.image._stream_output' )
264
+ @patch ('sagemaker.local.image._SageMakerContainer._cleanup' )
265
+ @patch ('sagemaker.local.image._SageMakerContainer._download_folder' )
266
+ def test_train_with_hyperparameters_without_job_name (_download_folder , _cleanup , _stream_output , LocalSession , tmpdir ):
267
+
268
+ directories = [str (tmpdir .mkdir ('container-root' )), str (tmpdir .mkdir ('data' ))]
269
+ with patch ('sagemaker.local.image._SageMakerContainer._create_tmp_folder' ,
270
+ side_effect = directories ):
271
+
272
+ instance_count = 2
273
+ image = 'my-image'
274
+ sagemaker_container = _SageMakerContainer ('local' , instance_count , image , sagemaker_session = LocalSession )
275
+ sagemaker_container .train (INPUT_DATA_CONFIG , HYPERPARAMETERS , TRAINING_JOB_NAME )
276
+
277
+ docker_compose_file = os .path .join (sagemaker_container .container_root , 'docker-compose.yaml' )
278
+
279
+ with open (docker_compose_file , 'r' ) as f :
280
+ config = yaml .load (f )
281
+ for h in sagemaker_container .hosts :
282
+ assert 'TRAINING_JOB_NAME={}' .format (TRAINING_JOB_NAME ) in config ['services' ][h ]['environment' ]
283
+
284
+
262
285
@patch ('sagemaker.local.local_session.LocalSession' )
263
286
@patch ('sagemaker.local.image._stream_output' , side_effect = RuntimeError ('this is expected' ))
264
287
@patch ('subprocess.Popen' )
@@ -273,7 +296,7 @@ def test_train_error(_download_folder, _cleanup, popen, _stream_output, LocalSes
273
296
sagemaker_container = _SageMakerContainer ('local' , instance_count , image , sagemaker_session = sagemaker_session )
274
297
275
298
with pytest .raises (RuntimeError ) as e :
276
- sagemaker_container .train (INPUT_DATA_CONFIG , HYPERPARAMETERS )
299
+ sagemaker_container .train (INPUT_DATA_CONFIG , HYPERPARAMETERS , TRAINING_JOB_NAME )
277
300
278
301
assert 'this is expected' in str (e )
279
302
@@ -293,7 +316,7 @@ def test_train_local_code(_download_folder, _cleanup, popen, _stream_output,
293
316
sagemaker_container = _SageMakerContainer ('local' , instance_count , image ,
294
317
sagemaker_session = sagemaker_session )
295
318
296
- sagemaker_container .train (INPUT_DATA_CONFIG , LOCAL_CODE_HYPERPARAMETERS )
319
+ sagemaker_container .train (INPUT_DATA_CONFIG , LOCAL_CODE_HYPERPARAMETERS , TRAINING_JOB_NAME )
297
320
298
321
docker_compose_file = os .path .join (sagemaker_container .container_root ,
299
322
'docker-compose.yaml' )
0 commit comments