
Commit 7357c4e

Enable Multithreading on msgpack Chunking in BulkImportWriter (#142)
* Enable multithreading on msgpack chunking
* Remove redundant parameter from different PR
* Slight refactor
1 parent 6490412
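
The commit threads the msgpack chunk-serialization step of BulkImportWriter.write_dataframe across a worker pool sized by max_workers. A hedged usage sketch follows; it assumes that keyword arguments passed to pytd.Client.load_table_from_dataframe are forwarded to BulkImportWriter.write_dataframe, which this page does not itself show:

    import pandas as pd
    import pytd

    # "1/XXX" is a placeholder API key, not a real credential.
    client = pytd.Client(apikey="1/XXX", database="sample_db")
    df = pd.DataFrame({"id": range(100_000), "value": range(100_000)})

    # Assumption: fmt and max_workers reach BulkImportWriter.write_dataframe,
    # where max_workers sizes the ThreadPoolExecutor added by this commit.
    client.load_table_from_dataframe(
        df,
        "sample_db.sample_table",
        writer="bulk_import",
        if_exists="overwrite",
        fmt="msgpack",
        max_workers=5,
    )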

1 file changed: +22 -12 lines changed

pytd/writer.py

Lines changed: 22 additions & 12 deletions
@@ -452,19 +452,29 @@ def write_dataframe(
         _replace_pd_na(dataframe)
         num_rows = len(dataframe)
         # chunk number of records should not exceed 200 to avoid OSError
-        _chunk_record_size = max(chunk_record_size, num_rows//200)
+        _chunk_record_size = max(chunk_record_size, num_rows // 200)
         try:
-            for start in range(0, num_rows, _chunk_record_size):
-                records = dataframe.iloc[
-                    start : start + _chunk_record_size
-                ].to_dict(orient="records")
-                fp = tempfile.NamedTemporaryFile(
-                    suffix=".msgpack.gz", delete=False
-                )
-                fp = self._write_msgpack_stream(records, fp)
-                fps.append(fp)
-                stack.callback(os.unlink, fp.name)
-                stack.callback(fp.close)
+            with ThreadPoolExecutor(max_workers=max_workers) as executor:
+                futures = []
+                for start in range(0, num_rows, _chunk_record_size):
+                    records = dataframe.iloc[
+                        start : start + _chunk_record_size
+                    ].to_dict(orient="records")
+                    fp = tempfile.NamedTemporaryFile(
+                        suffix=".msgpack.gz", delete=False
+                    )
+                    futures.append(
+                        (
+                            start,
+                            executor.submit(
+                                self._write_msgpack_stream, records, fp
+                            ),
+                        )
+                    )
+                    stack.callback(os.unlink, fp.name)
+                    stack.callback(fp.close)
+                for start, future in sorted(futures):
+                    fps.append(future.result())
         except OSError as e:
             raise RuntimeError(
                 "failed to create a temporary file. "
