Skip to content

Commit 637d360

Browse files
Add job_name parameter for local mode (#424)
1 parent d9c03a1 commit 637d360

File tree

5 files changed

+38
-13
lines changed

5 files changed

+38
-13
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ CHANGELOG
88

99
* enhancement: Enable setting VPC config when creating/deploying models
1010
* enhancement: Local Mode: accept short lived credentials with a warning message
11+
* bug-fix: Local Mode: pass in job name as parameter for training environment variable
1112

1213
=======
1314
1.11.1

src/sagemaker/local/entities.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def __init__(self, container):
4141
self.start_time = None
4242
self.end_time = None
4343

44-
def start(self, input_data_config, hyperparameters):
44+
def start(self, input_data_config, hyperparameters, job_name):
4545
for channel in input_data_config:
4646
if channel['DataSource'] and 'S3DataSource' in channel['DataSource']:
4747
data_distribution = channel['DataSource']['S3DataSource']['S3DataDistributionType']
@@ -57,7 +57,7 @@ def start(self, input_data_config, hyperparameters):
5757
self.start = datetime.datetime.now()
5858
self.state = self._TRAINING
5959

60-
self.model_artifacts = self.container.train(input_data_config, hyperparameters)
60+
self.model_artifacts = self.container.train(input_data_config, hyperparameters, job_name)
6161
self.end = datetime.datetime.now()
6262
self.state = self._COMPLETED
6363

src/sagemaker/local/image.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,13 @@ def __init__(self, instance_type, instance_count, image, sagemaker_session=None)
7979
self.container_root = None
8080
self.container = None
8181

82-
def train(self, input_data_config, hyperparameters):
82+
def train(self, input_data_config, hyperparameters, job_name):
8383
"""Run a training job locally using docker-compose.
8484
Args:
8585
input_data_config (dict): The Input Data Configuration, this contains data such as the
8686
channels to be used for training.
8787
hyperparameters (dict): The HyperParameters for the training job.
88+
job_name (str): Name of the local training job being run.
8889
8990
Returns (str): Location of the trained model.
9091
"""
@@ -109,7 +110,7 @@ def train(self, input_data_config, hyperparameters):
109110

110111
training_env_vars = {
111112
REGION_ENV_NAME: self.sagemaker_session.boto_region_name,
112-
TRAINING_JOB_NAME_ENV_NAME: json.loads(hyperparameters.get(sagemaker.model.JOB_NAME_PARAM_NAME)),
113+
TRAINING_JOB_NAME_ENV_NAME: job_name,
113114
}
114115
compose_data = self._generate_compose_file('train', additional_volumes=volumes,
115116
additional_env_vars=training_env_vars)

src/sagemaker/local/local_session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def create_training_job(self, TrainingJobName, AlgorithmSpecification, InputData
7272
AlgorithmSpecification['TrainingImage'], self.sagemaker_session)
7373
training_job = _LocalTrainingJob(container)
7474
hyperparameters = kwargs['HyperParameters'] if 'HyperParameters' in kwargs else {}
75-
training_job.start(InputDataConfig, hyperparameters)
75+
training_job.start(InputDataConfig, hyperparameters, TrainingJobName)
7676

7777
LocalSagemakerClient._training_jobs[TrainingJobName] = training_job
7878

tests/unit/test_image.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
REGION = 'us-west-2'
3333
BUCKET_NAME = 'mybucket'
3434
EXPANDED_ROLE = 'arn:aws:iam::111111111111:role/ExpandedRole'
35+
TRAINING_JOB_NAME = 'my-job'
3536
INPUT_DATA_CONFIG = [
3637
{
3738
'ChannelName': 'a',
@@ -55,13 +56,12 @@
5556
]
5657
HYPERPARAMETERS = {'a': 1,
5758
'b': json.dumps('bee'),
58-
'sagemaker_submit_directory': json.dumps('s3://my_bucket/code'),
59-
'sagemaker_job_name': json.dumps('my-job')}
59+
'sagemaker_submit_directory': json.dumps('s3://my_bucket/code')}
60+
6061

6162
LOCAL_CODE_HYPERPARAMETERS = {'a': 1,
6263
'b': 2,
63-
'sagemaker_submit_directory': json.dumps('file:///tmp/code'),
64-
'sagemaker_job_name': json.dumps('my-job')}
64+
'sagemaker_submit_directory': json.dumps('file:///tmp/code')}
6565

6666

6767
@pytest.fixture()
@@ -230,7 +230,7 @@ def test_train(_download_folder, _cleanup, popen, _stream_output, LocalSession,
230230
instance_count = 2
231231
image = 'my-image'
232232
sagemaker_container = _SageMakerContainer('local', instance_count, image, sagemaker_session=sagemaker_session)
233-
sagemaker_container.train(INPUT_DATA_CONFIG, HYPERPARAMETERS)
233+
sagemaker_container.train(INPUT_DATA_CONFIG, HYPERPARAMETERS, TRAINING_JOB_NAME)
234234

235235
channel_dir = os.path.join(directories[1], 'b')
236236
download_folder_calls = [call('my-own-bucket', 'prefix', channel_dir)]
@@ -252,13 +252,36 @@ def test_train(_download_folder, _cleanup, popen, _stream_output, LocalSession,
252252
assert config['services'][h]['image'] == image
253253
assert config['services'][h]['command'] == 'train'
254254
assert 'AWS_REGION={}'.format(REGION) in config['services'][h]['environment']
255-
assert 'TRAINING_JOB_NAME=my-job' in config['services'][h]['environment']
255+
assert 'TRAINING_JOB_NAME={}'.format(TRAINING_JOB_NAME) in config['services'][h]['environment']
256256

257257
# assert that the output directories expected by the sagemaker container exist
258258
assert os.path.exists(os.path.join(sagemaker_container.container_root, 'output'))
259259
assert os.path.exists(os.path.join(sagemaker_container.container_root, 'output/data'))
260260

261261

262+
@patch('sagemaker.local.local_session.LocalSession')
@patch('sagemaker.local.image._stream_output')
@patch('sagemaker.local.image._SageMakerContainer._cleanup')
@patch('sagemaker.local.image._SageMakerContainer._download_folder')
def test_train_with_hyperparameters_without_job_name(_download_folder, _cleanup, _stream_output, LocalSession, tmpdir):
    """Check that train() writes the explicit job name into the compose file.

    HYPERPARAMETERS carries no 'sagemaker_job_name' entry, so the
    TRAINING_JOB_NAME value found in each host's environment has to come from
    the ``job_name`` argument passed to ``train``.
    """
    # NOTE(review): unlike test_train / test_train_error, subprocess.Popen is
    # not patched here -- confirm that train() does not spawn a real
    # docker-compose process under this test.

    # _create_tmp_folder is patched to hand out these two directories in order.
    directories = [str(tmpdir.mkdir('container-root')), str(tmpdir.mkdir('data'))]
    with patch('sagemaker.local.image._SageMakerContainer._create_tmp_folder',
               side_effect=directories):

        instance_count = 2
        image = 'my-image'
        sagemaker_container = _SageMakerContainer('local', instance_count, image, sagemaker_session=LocalSession)
        sagemaker_container.train(INPUT_DATA_CONFIG, HYPERPARAMETERS, TRAINING_JOB_NAME)

        docker_compose_file = os.path.join(sagemaker_container.container_root, 'docker-compose.yaml')

        expected_env_entry = 'TRAINING_JOB_NAME={}'.format(TRAINING_JOB_NAME)
        with open(docker_compose_file, 'r') as f:
            # safe_load: yaml.load() without an explicit Loader is deprecated
            # (and rejected by newer PyYAML); the compose file is plain data.
            config = yaml.safe_load(f)
            for h in sagemaker_container.hosts:
                assert expected_env_entry in config['services'][h]['environment']
283+
284+
262285
@patch('sagemaker.local.local_session.LocalSession')
263286
@patch('sagemaker.local.image._stream_output', side_effect=RuntimeError('this is expected'))
264287
@patch('subprocess.Popen')
@@ -273,7 +296,7 @@ def test_train_error(_download_folder, _cleanup, popen, _stream_output, LocalSes
273296
sagemaker_container = _SageMakerContainer('local', instance_count, image, sagemaker_session=sagemaker_session)
274297

275298
with pytest.raises(RuntimeError) as e:
276-
sagemaker_container.train(INPUT_DATA_CONFIG, HYPERPARAMETERS)
299+
sagemaker_container.train(INPUT_DATA_CONFIG, HYPERPARAMETERS, TRAINING_JOB_NAME)
277300

278301
assert 'this is expected' in str(e)
279302

@@ -293,7 +316,7 @@ def test_train_local_code(_download_folder, _cleanup, popen, _stream_output,
293316
sagemaker_container = _SageMakerContainer('local', instance_count, image,
294317
sagemaker_session=sagemaker_session)
295318

296-
sagemaker_container.train(INPUT_DATA_CONFIG, LOCAL_CODE_HYPERPARAMETERS)
319+
sagemaker_container.train(INPUT_DATA_CONFIG, LOCAL_CODE_HYPERPARAMETERS, TRAINING_JOB_NAME)
297320

298321
docker_compose_file = os.path.join(sagemaker_container.container_root,
299322
'docker-compose.yaml')

0 commit comments

Comments
 (0)