Skip to content

Commit 3468d9f

Browse files
authored
Merge pull request #3130 from AKSoo/patch-1
FIX: DataSink to S3 buckets
2 parents e5664a2 + 847553e commit 3468d9f

File tree

1 file changed

+7
-21
lines changed

1 file changed

+7
-21
lines changed

nipype/interfaces/io.py

+7-21
Original file line numberDiff line numberDiff line change
@@ -212,12 +212,12 @@ class DataSinkInputSpec(DynamicTraitedSpec, BaseInterfaceInputSpec):
212212
"""
213213

214214
# Init inputspec data attributes
215-
base_directory = Directory(desc="Path to the base directory for storing data.")
215+
base_directory = Str(desc="Path to the base directory for storing data.")
216216
container = Str(desc="Folder within base directory in which to store output")
217217
parameterization = traits.Bool(
218218
True, usedefault=True, desc="store output in parametrized structure"
219219
)
220-
strip_dir = Directory(desc="path to strip out of filename")
220+
strip_dir = Str(desc="path to strip out of filename")
221221
substitutions = InputMultiPath(
222222
traits.Tuple(Str, Str),
223223
desc=(
@@ -440,7 +440,6 @@ def _check_s3_base_dir(self):
440440
is not a valid S3 path, defaults to '<N/A>'
441441
"""
442442

443-
# Init variables
444443
s3_str = "s3://"
445444
bucket_name = "<N/A>"
446445
base_directory = self.inputs.base_directory
@@ -449,22 +448,10 @@ def _check_s3_base_dir(self):
449448
s3_flag = False
450449
return s3_flag, bucket_name
451450

452-
# Explicitly lower-case the "s3"
453-
if base_directory.lower().startswith(s3_str):
454-
base_dir_sp = base_directory.split("/")
455-
base_dir_sp[0] = base_dir_sp[0].lower()
456-
base_directory = "/".join(base_dir_sp)
457-
458-
# Check if 's3://' in base dir
459-
if base_directory.startswith(s3_str):
460-
# Expects bucket name to be 's3://bucket_name/base_dir/..'
461-
bucket_name = base_directory.split(s3_str)[1].split("/")[0]
462-
s3_flag = True
463-
# Otherwise it's just a normal datasink
464-
else:
465-
s3_flag = False
451+
s3_flag = base_directory.lower().startswith(s3_str)
452+
if s3_flag:
453+
bucket_name = base_directory[len(s3_str):].partition('/')[0]
466454

467-
# Return s3_flag
468455
return s3_flag, bucket_name
469456

470457
# Function to return AWS secure environment variables
@@ -618,13 +605,12 @@ def _upload_to_s3(self, bucket, src, dst):
618605

619606
from botocore.exceptions import ClientError
620607

621-
# Init variables
622608
s3_str = "s3://"
623609
s3_prefix = s3_str + bucket.name
624610

625611
# Explicitly lower-case the "s3"
626-
if dst[: len(s3_str)].lower() == s3_str:
627-
dst = s3_str + dst[len(s3_str) :]
612+
if dst.lower().startswith(s3_str):
613+
dst = s3_str + dst[len(s3_str):]
628614

629615
# If src is a directory, collect files (this assumes dst is a dir too)
630616
if os.path.isdir(src):

0 commit comments

Comments
 (0)