Commit caf15c8

put_file: support concurrent multipart uploads with max_concurrency (#848)
1 parent 74f4d95 commit caf15c8

File tree

1 file changed

+76 -19 lines changed

s3fs/core.py

+76 -19
@@ -235,6 +235,14 @@ class S3FileSystem(AsyncFileSystem):
     session : aiobotocore AioSession object to be used for all connections.
         This session will be used inplace of creating a new session inside S3FileSystem.
         For example: aiobotocore.session.AioSession(profile='test_user')
+    max_concurrency : int (1)
+        The maximum number of concurrent transfers to use per file for multipart
+        upload (``put()``) operations. Defaults to 1 (sequential). When used in
+        conjunction with ``S3FileSystem.put(batch_size=...)`` the maximum number of
+        simultaneous connections is ``max_concurrency * batch_size``. We may extend
+        this parameter to affect ``pipe()``, ``cat()`` and ``get()``. Increasing this
+        value will result in higher memory usage during multipart upload operations (by
+        ``max_concurrency * chunksize`` bytes per file).

     The following parameters are passed on to fsspec:

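
A minimal usage sketch of the new option (not part of the commit), assuming a hypothetical bucket and local paths; it just applies the arithmetic described in the docstring above, with the default chunksize of 50 * 2**20 (50 MiB):

    import s3fs

    # Hypothetical bucket/paths, for illustration only.
    fs = s3fs.S3FileSystem(max_concurrency=4)  # up to 4 concurrent part uploads per file

    # One large file: parts go up 4 at a time, buffering roughly
    # max_concurrency * chunksize = 4 * 50 MiB = 200 MiB for this file.
    fs.put("/data/large.bin", "my-bucket/large.bin")

    # Many files: with batch_size=8, up to max_concurrency * batch_size = 32
    # simultaneous connections.
    fs.put("/data/dir/", "my-bucket/dir/", recursive=True, batch_size=8)
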
@@ -282,6 +290,7 @@ def __init__(
         cache_regions=False,
         asynchronous=False,
         loop=None,
+        max_concurrency=1,
         **kwargs,
     ):
         if key and username:
@@ -319,6 +328,9 @@ def __init__(
         self.cache_regions = cache_regions
         self._s3 = None
         self.session = session
+        if max_concurrency < 1:
+            raise ValueError("max_concurrency must be >= 1")
+        self.max_concurrency = max_concurrency

     @property
     def s3(self):
@@ -1140,7 +1152,13 @@ async def _pipe_file(self, path, data, chunksize=50 * 2**20, **kwargs):
         self.invalidate_cache(path)

     async def _put_file(
-        self, lpath, rpath, callback=_DEFAULT_CALLBACK, chunksize=50 * 2**20, **kwargs
+        self,
+        lpath,
+        rpath,
+        callback=_DEFAULT_CALLBACK,
+        chunksize=50 * 2**20,
+        max_concurrency=None,
+        **kwargs,
     ):
         bucket, key, _ = self.split_path(rpath)
         if os.path.isdir(lpath):
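
Since ``_put_file`` now takes ``max_concurrency=None`` and falls back to the instance-level value (see ``_upload_file_part_concurrent`` below), the setting can also be overridden per call. A hedged sketch, assuming fsspec forwards keyword arguments from the mirrored sync ``put_file`` to this coroutine as usual; paths are placeholders:

    fs = s3fs.S3FileSystem(max_concurrency=1)  # sequential by default

    # Override concurrency for one especially large upload.
    fs.put_file("/data/huge.tar", "my-bucket/huge.tar", max_concurrency=8)
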
@@ -1169,24 +1187,15 @@ async def _put_file(
                 mpu = await self._call_s3(
                     "create_multipart_upload", Bucket=bucket, Key=key, **kwargs
                 )
-
-                out = []
-                while True:
-                    chunk = f0.read(chunksize)
-                    if not chunk:
-                        break
-                    out.append(
-                        await self._call_s3(
-                            "upload_part",
-                            Bucket=bucket,
-                            PartNumber=len(out) + 1,
-                            UploadId=mpu["UploadId"],
-                            Body=chunk,
-                            Key=key,
-                        )
-                    )
-                    callback.relative_update(len(chunk))
-
+                out = await self._upload_file_part_concurrent(
+                    bucket,
+                    key,
+                    mpu,
+                    f0,
+                    callback=callback,
+                    chunksize=chunksize,
+                    max_concurrency=max_concurrency,
+                )
                 parts = [
                     {"PartNumber": i + 1, "ETag": o["ETag"]} for i, o in enumerate(out)
                 ]
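
The ``parts`` list pairs each ETag with ``PartNumber=i + 1`` purely by position, so it relies on ``out`` preserving upload order. That holds because ``asyncio.gather`` returns results in the order the awaitables were passed, not in completion order. A small self-contained illustration of that property (not part of the commit):

    import asyncio

    async def fake_upload(part_number, delay):
        await asyncio.sleep(delay)  # later parts may finish first
        return {"PartNumber": part_number}

    async def main():
        # Part 1 is slowest, yet gather still returns results in submission order.
        results = await asyncio.gather(
            fake_upload(1, 0.3), fake_upload(2, 0.2), fake_upload(3, 0.1)
        )
        print([r["PartNumber"] for r in results])  # [1, 2, 3]

    asyncio.run(main())
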
@@ -1201,6 +1210,54 @@ async def _put_file(
             self.invalidate_cache(rpath)
             rpath = self._parent(rpath)

+    async def _upload_file_part_concurrent(
+        self,
+        bucket,
+        key,
+        mpu,
+        f0,
+        callback=_DEFAULT_CALLBACK,
+        chunksize=50 * 2**20,
+        max_concurrency=None,
+    ):
+        max_concurrency = max_concurrency or self.max_concurrency
+        if max_concurrency < 1:
+            raise ValueError("max_concurrency must be >= 1")
+
+        async def _upload_chunk(chunk, part_number):
+            result = await self._call_s3(
+                "upload_part",
+                Bucket=bucket,
+                PartNumber=part_number,
+                UploadId=mpu["UploadId"],
+                Body=chunk,
+                Key=key,
+            )
+            callback.relative_update(len(chunk))
+            return result
+
+        out = []
+        while True:
+            chunks = []
+            for i in range(max_concurrency):
+                chunk = f0.read(chunksize)
+                if chunk:
+                    chunks.append(chunk)
+            if not chunks:
+                break
+            if len(chunks) > 1:
+                out.extend(
+                    await asyncio.gather(
+                        *[
+                            _upload_chunk(chunk, len(out) + i)
+                            for i, chunk in enumerate(chunks, 1)
+                        ]
+                    )
+                )
+            else:
+                out.append(await _upload_chunk(chunk, len(out) + 1))
+        return out
+
     async def _get_file(
         self, rpath, lpath, callback=_DEFAULT_CALLBACK, version_id=None
     ):
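
To see the control flow of ``_upload_file_part_concurrent`` in isolation, here is a stripped-down sketch of the same read-a-batch-then-gather pattern with the S3 call replaced by a stub; the names, sizes and data are illustrative only:

    import asyncio
    import io

    CHUNKSIZE = 4        # tiny, for illustration; s3fs defaults to 50 * 2**20 (50 MiB)
    MAX_CONCURRENCY = 3

    async def upload_part(part_number, chunk):
        # Stand-in for the real "upload_part" S3 call.
        await asyncio.sleep(0)
        return {"PartNumber": part_number, "size": len(chunk)}

    async def upload_file(f0):
        out = []
        while True:
            # Read at most MAX_CONCURRENCY chunks, so at most
            # MAX_CONCURRENCY * CHUNKSIZE bytes are held in memory at once.
            chunks = []
            for _ in range(MAX_CONCURRENCY):
                chunk = f0.read(CHUNKSIZE)
                if chunk:
                    chunks.append(chunk)
            if not chunks:
                break
            # Upload the batch concurrently; gather keeps submission order,
            # so part numbers stay consecutive.
            out.extend(
                await asyncio.gather(
                    *[
                        upload_part(len(out) + i, chunk)
                        for i, chunk in enumerate(chunks, 1)
                    ]
                )
            )
        return out

    parts = asyncio.run(upload_file(io.BytesIO(b"0123456789abcdef0")))
    print(parts)  # five parts of 4, 4, 4, 4 and 1 bytes
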
