Commit 08d6a36

Create temporary file within ThreadPoolExecutor
1 parent f597395 commit 08d6a36

File tree

1 file changed: 38 additions, 45 deletions


pytd/writer.py

Lines changed: 38 additions & 45 deletions
@@ -439,46 +439,9 @@ def write_dataframe(
             fmt = "msgpack"
 
         _cast_dtypes(dataframe, keep_list=keep_list)
+        self._bulk_import(table, dataframe, if_exists, fmt, max_workers=max_workers, chunk_record_size=chunk_record_size)
 
-        with ExitStack() as stack:
-            fps = []
-            if fmt == "csv":
-                fp = tempfile.NamedTemporaryFile(suffix=".csv", delete=False)
-                stack.callback(os.unlink, fp.name)
-                stack.callback(fp.close)
-                dataframe.to_csv(fp.name)
-                fps.append(fp)
-            elif fmt == "msgpack":
-                _replace_pd_na(dataframe)
-                num_rows = len(dataframe)
-                # chunk number of records should not exceed 200 to avoid OSError
-                _chunk_record_size = max(chunk_record_size, num_rows//200)
-                try:
-                    for start in range(0, num_rows, _chunk_record_size):
-                        records = dataframe.iloc[
-                            start : start + _chunk_record_size
-                        ].to_dict(orient="records")
-                        fp = tempfile.NamedTemporaryFile(
-                            suffix=".msgpack.gz", delete=False
-                        )
-                        fp = self._write_msgpack_stream(records, fp)
-                        fps.append(fp)
-                        stack.callback(os.unlink, fp.name)
-                        stack.callback(fp.close)
-                except OSError as e:
-                    raise RuntimeError(
-                        "failed to create a temporary file. "
-                        "Larger chunk_record_size may mitigate the issue."
-                    ) from e
-            else:
-                raise ValueError(
-                    f"unsupported format '{fmt}' for bulk import. "
-                    "should be 'csv' or 'msgpack'"
-                )
-            self._bulk_import(table, fps, if_exists, fmt, max_workers=max_workers)
-            stack.close()
-
-    def _bulk_import(self, table, file_likes, if_exists, fmt="csv", max_workers=5):
+    def _bulk_import(self, table, dataframe, if_exists, fmt="csv", max_workers=5, chunk_record_size=10_000):
         """Write a specified CSV file to a Treasure Data table.
 
         This method uploads the file to Treasure Data via bulk import API.
@@ -488,8 +451,7 @@ def _bulk_import(self, table, file_likes, if_exists, fmt="csv", max_workers=5):
         table : :class:`pytd.table.Table`
             Target table.
 
-        file_likes : List of file like objects
-            Data in this file will be loaded to a target table.
+        dataframe : DataFrame to be uploaded
 
         if_exists : str, {'error', 'overwrite', 'append', 'ignore'}
             What happens when a target table already exists.
@@ -505,6 +467,10 @@ def _bulk_import(self, table, file_likes, if_exists, fmt="csv", max_workers=5):
         max_workers : int, optional, default: 5
             The maximum number of threads that can be used to execute the given calls.
             This is used only when ``fmt`` is ``msgpack``.
+
+        chunk_record_size : int, optional, default: 10_000
+            The number of records to be written in a single file. This is used only when
+            ``fmt`` is ``msgpack``.
         """
         params = None
         if table.exists:
@@ -530,11 +496,30 @@ def _bulk_import(self, table, file_likes, if_exists, fmt="csv", max_workers=5):
             session_name, table.database, table.table, params=params
         )
         s_time = time.time()
+        file_paths = []
         try:
            logger.info(f"uploading data converted into a {fmt} file")
-            if fmt == "msgpack":
+            if fmt == "csv":
+                fp = tempfile.NamedTemporaryFile(suffix=".csv", delete=False)
+                file_paths.append(fp.name)
+                dataframe.to_csv(fp.name)
+                bulk_import.upload_file("part", fmt, fp)
+                os.unlink(fp.name)
+                fp.close()
+            elif fmt == "msgpack":
+                _replace_pd_na(dataframe)
+                num_rows = len(dataframe)
+
                 with ThreadPoolExecutor(max_workers=max_workers) as executor:
-                    for i, fp in enumerate(file_likes):
+                    for i, start in enumerate(range(0, num_rows, chunk_record_size)):
+                        records = dataframe.iloc[
+                            start : start + chunk_record_size
+                        ].to_dict(orient="records")
+                        fp = tempfile.NamedTemporaryFile(
+                            suffix=".msgpack.gz", delete=False
+                        )
+                        file_paths.append(fp.name)
+                        fp = self._write_msgpack_stream(records, fp)
                         fsize = fp.tell()
                         fp.seek(0)
                         executor.submit(
@@ -544,13 +529,21 @@ def _bulk_import(self, table, file_likes, if_exists, fmt="csv", max_workers=5):
                             fsize,
                         )
                         logger.debug(f"to upload {fp.name} to TD. File size: {fsize}B")
+                        os.unlink(fp.name)
+                        fp.close()
             else:
-                fp = file_likes[0]
-                bulk_import.upload_file("part", fmt, fp)
+                raise ValueError(
+                    f"unsupported format '{fmt}' for bulk import. "
+                    "should be 'csv' or 'msgpack'"
+                )
             bulk_import.freeze()
         except Exception as e:
             bulk_import.delete()
             raise RuntimeError(f"failed to upload file: {e}")
+        finally:
+            for fp in file_paths:
+                if os.path.exists(fp):
+                    os.unlink(fp)
 
         logger.debug(f"uploaded data in {time.time() - s_time:.2f} sec")
 
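For context, a minimal usage sketch of the code path touched by this commit. The API key, database, and table names are placeholders, and it assumes `pytd.Client.load_table_from_dataframe` forwards keyword arguments such as `fmt`, `max_workers`, and `chunk_record_size` to `BulkImportWriter.write_dataframe`; it is an illustration, not part of the commit.

import pandas as pd
import pytd

# Hypothetical credentials and names; replace with real values.
client = pytd.Client(apikey="1/XXXXXXXX", database="sample_db")

df = pd.DataFrame({"id": range(50_000), "value": [0.5] * 50_000})

# With this change, the msgpack path slices the DataFrame into chunks of
# chunk_record_size records, writes each chunk to a temporary .msgpack.gz file
# inside _bulk_import's ThreadPoolExecutor block, submits it for upload, and
# removes the file afterwards (with a finally block as a safety net).
client.load_table_from_dataframe(
    df,
    "bulk_import_example",  # hypothetical destination table
    writer="bulk_import",
    if_exists="overwrite",
    fmt="msgpack",
    max_workers=5,
    chunk_record_size=10_000,
)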