Skip to content

Commit b3aae36

Browse files
authored
Add prefetch parameter to PyTorchTileDBDataLoader and TensorflowTileDBDataset (#131)
1 parent 9e391d7 commit b3aae36

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

tiledb/ml/readers/pytorch.py

+5
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,11 @@
1818
def PyTorchTileDBDataLoader(
1919
x_array: tiledb.Array,
2020
y_array: tiledb.Array,
21+
*,
2122
batch_size: int,
2223
buffer_bytes: Optional[int] = None,
2324
shuffle_buffer_size: int = 0,
25+
prefetch: int = 2,
2426
x_attrs: Sequence[str] = (),
2527
y_attrs: Sequence[str] = (),
2628
num_workers: int = 0,
@@ -32,6 +34,8 @@ def PyTorchTileDBDataLoader(
3234
:param batch_size: Size of each batch.
3335
:param buffer_bytes: Maximum size (in bytes) of memory to allocate for reading
3436
from each array (default=`tiledb.default_ctx().config()["sm.memory_budget"]`).
37+
:param prefetch: Number of samples loaded in advance by each worker. Not applicable
38+
(and should not be given) when `num_workers` is 0.
3539
:param shuffle_buffer_size: Number of elements from which this dataset will sample.
3640
:param x_attrs: Attribute names of x_array.
3741
:param y_attrs: Attribute names of y_array.
@@ -49,6 +53,7 @@ def PyTorchTileDBDataLoader(
4953
x_array, y_array, buffer_bytes, shuffle_buffer_size, x_attrs, y_attrs
5054
),
5155
batch_size=batch_size,
56+
prefetch_factor=prefetch,
5257
num_workers=num_workers,
5358
collate_fn=CompositeCollator(
5459
(

tiledb/ml/readers/tensorflow.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,11 @@
2020
def TensorflowTileDBDataset(
2121
x_array: tiledb.Array,
2222
y_array: tiledb.Array,
23+
*,
2324
batch_size: int,
2425
buffer_bytes: Optional[int] = None,
2526
shuffle_buffer_size: int = 0,
27+
prefetch: int = tf.data.AUTOTUNE,
2628
x_attrs: Sequence[str] = (),
2729
y_attrs: Sequence[str] = (),
2830
) -> tf.data.Dataset:
@@ -34,6 +36,8 @@ def TensorflowTileDBDataset(
3436
:param buffer_bytes: Maximum size (in bytes) of memory to allocate for reading from
3537
each array (default=`tiledb.default_ctx().config()["sm.memory_budget"]`).
3638
:param shuffle_buffer_size: Number of elements from which this dataset will sample.
39+
:param prefetch: Maximum number of batches that will be buffered when prefetching.
40+
By default, the buffer size is dynamically tuned.
3741
:param x_attrs: Attribute names of x_array.
3842
:param y_attrs: Attribute names of y_array.
3943
"""
@@ -66,7 +70,7 @@ def TensorflowTileDBDataset(
6670
dataset = dataset.unbatch()
6771
if shuffle_buffer_size > 0:
6872
dataset = dataset.shuffle(shuffle_buffer_size)
69-
return dataset.batch(batch_size)
73+
return dataset.batch(batch_size).prefetch(prefetch)
7074

7175

7276
def _iter_tensor_specs(

0 commit comments

Comments
 (0)