Skip to content

Commit c72121b

Browse files
authored
Fix LocalMode s3 training data when there is a trailing slash. (#208)
If the S3 training location is passed as s3://my_bucket/training/ (with a trailing slash) Then the objects would have incorrect filenames locally and the container would not find them. This fixes this bug.
1 parent e9ed907 commit c72121b

File tree

3 files changed

+13
-2
lines changed

3 files changed

+13
-2
lines changed

CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ CHANGELOG
88
* bug-fix: Unit Tests: Improve unit test runtime
99
* bug-fix: Estimators: Fix attach for LDA
1010
* bug-fix: Estimators: allow code_location to have no key prefix
11+
* bug-fix: Local Mode: Fix s3 training data download when there is a trailing slash
12+
1113

1214
1.4.1
1315
=====

src/sagemaker/local/image.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -267,15 +267,15 @@ def _download_folder(self, bucket_name, prefix, target):
267267

268268
for obj_sum in bucket.objects.filter(Prefix=prefix):
269269
obj = s3.Object(obj_sum.bucket_name, obj_sum.key)
270-
file_path = os.path.join(target, obj_sum.key[len(prefix) + 1:])
270+
s3_relative_path = obj_sum.key[len(prefix):].lstrip('/')
271+
file_path = os.path.join(target, s3_relative_path)
271272

272273
try:
273274
os.makedirs(os.path.dirname(file_path))
274275
except OSError as exc:
275276
if exc.errno != errno.EEXIST:
276277
raise
277278
pass
278-
279279
obj.download_file(file_path)
280280

281281
def _prepare_training_volumes(self, data_dir, input_data_config, hyperparameters):

tests/unit/test_image.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,15 @@ def test_download_folder(makedirs):
366366
calls = [call(os.path.join('/tmp', 'train/train_data.csv')),
367367
call(os.path.join('/tmp', 'train/validation_data.csv'))]
368368
obj_mock.download_file.assert_has_calls(calls)
369+
obj_mock.reset_mock()
370+
371+
# Testing with a trailing slash for the prefix.
372+
sagemaker_container._download_folder(BUCKET_NAME, '/prefix/', '/tmp')
373+
obj_mock.download_file.assert_called()
374+
calls = [call(os.path.join('/tmp', 'train/train_data.csv')),
375+
call(os.path.join('/tmp', 'train/validation_data.csv'))]
376+
377+
obj_mock.download_file.assert_has_calls(calls)
369378

370379

371380
def test_ecr_login_non_ecr():

0 commit comments

Comments
 (0)