Commit 7e5be1b

Replace shuffle parameter with shuffle buffer size (#126)
* Drop shuffle parameter from tensor_generator
* Replace shuffle parameter with shuffle_buffer_size (unused)
* Implement shuffle_buffer_size for TensorflowTileDBDataset
* Change PyTorchTileDBDataLoader from class to function
* Implement shuffle_buffer_size for PyTorchTileDBDataset
1 parent 891cc35 commit 7e5be1b
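
For callers the change is mechanical: the boolean shuffle flag becomes an explicit shuffle buffer size (0 disables shuffling). A minimal usage sketch based on the notebook updates in this commit; the import path and array URIs are assumptions for illustration, not part of the diff:

import tiledb
# Assumed module path inferred from the repo layout (tiledb/ml/readers/);
# the actual import may differ.
from tiledb.ml.readers.tensorflow import TensorflowTileDBDataset

# "training_images" / "training_labels" are placeholder TileDB array URIs.
with tiledb.open("training_images") as x, tiledb.open("training_labels") as y:
    tiledb_dataset = TensorflowTileDBDataset(
        x_array=x,
        y_array=y,
        x_attrs=["features"],
        y_attrs=["features"],
        batch_size=64,
        buffer_bytes=1024**2,
        shuffle_buffer_size=128,  # was shuffle=True before this commit
    )
    # tiledb_dataset is a tf.data.Dataset and can be passed to model.fit().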

File tree

10 files changed: +119, -121 lines changed


examples/cloud/models/pytorch_tiledb_cloud_ml_model_array.ipynb

Lines changed: 1 addition & 2 deletions
@@ -32,7 +32,6 @@
 "\n",
 "epochs = 1\n",
 "batch_size_train = 128\n",
-"batch_size_test = 1000\n",
 "learning_rate = 0.01\n",
 "momentum = 0.5\n",
 "log_interval = 10\n",
@@ -119,7 +118,7 @@
 " torchvision.transforms.Normalize(\n",
 " (0.1307,), (0.3081,))\n",
 " ])),\n",
-" batch_size=batch_size_train, shuffle=True)"
+" batch_size=batch_size_train, shuffle_buffer_size=2*batch_size_train)"
 ]
 },
 {

examples/models/pytorch_tiledb_models_example.ipynb

Lines changed: 1 addition & 2 deletions
@@ -67,7 +67,6 @@
 "source": [
 "epochs = 1\n",
 "batch_size_train = 128\n",
-"batch_size_test = 1000\n",
 "learning_rate = 0.01\n",
 "momentum = 0.5\n",
 "log_interval = 10\n",
@@ -108,7 +107,7 @@
 " torchvision.transforms.Normalize(\n",
 " (0.1307,), (0.3081,))\n",
 " ])),\n",
-" batch_size=batch_size_train, shuffle=True)\n"
+" batch_size=batch_size_train, shuffle_buffer_size=2*batch_size_train)\n"
 ]
 },
 {

examples/readers/pytorch_data_api_tiledb_dense.ipynb

Lines changed: 1 addition & 1 deletion
@@ -372,7 +372,7 @@
 " train_loader = PyTorchTileDBDataLoader(x_array=x, y_array=y,\n",
 " batch_size=64,\n",
 " buffer_bytes=1024**2,\n",
-" shuffle=True)\n",
+" shuffle_buffer_size=128)\n",
 " net = Net(shape=(28, 28))\n",
 " criterion = nn.CrossEntropyLoss()\n",
 " optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.5)\n",

examples/readers/tensorflow_data_api_tiledb_dense.ipynb

Lines changed: 1 addition & 1 deletion
@@ -340,7 +340,7 @@
 "with tiledb.open(training_images) as x, tiledb.open(training_labels) as y:\n",
 " tiledb_dataset = TensorflowTileDBDataset(\n",
 " x_array=x, y_array=y, x_attrs=['features'], y_attrs=['features'], \n",
-" batch_size=64, buffer_bytes=1024**2, shuffle=True\n",
+" batch_size=64, buffer_bytes=1024**2, shuffle_buffer_size=128\n",
 " )\n",
 " model.fit(tiledb_dataset, epochs=5)"
 ]

tests/readers/test_pytorch.py

Lines changed: 13 additions & 15 deletions
@@ -27,9 +27,9 @@ def test_dataset(
         y_shape,
         num_attrs,
         pass_attrs,
-        batch_size,
         buffer_bytes,
-        shuffle,
+        batch_size,
+        shuffle_buffer_size,
     ):
         with ingest_in_tiledb(
             tmpdir,
@@ -40,9 +40,7 @@ def test_dataset(
             num_attrs=num_attrs,
             pass_attrs=pass_attrs,
         ) as kwargs:
-            dataset = PyTorchTileDBDataset(
-                buffer_bytes=buffer_bytes, shuffle=shuffle, **kwargs
-            )
+            dataset = PyTorchTileDBDataset(buffer_bytes=buffer_bytes, **kwargs)
             assert isinstance(dataset, torch.utils.data.IterableDataset)
             validate_tensor_generator(
                 dataset, num_attrs, x_sparse, y_sparse, x_shape, y_shape
@@ -61,9 +59,9 @@ def test_dataloader(
         y_shape,
         num_attrs,
         pass_attrs,
-        batch_size,
         buffer_bytes,
-        shuffle,
+        batch_size,
+        shuffle_buffer_size,
     ):
         if num_workers and (x_sparse or y_sparse):
             pytest.skip("multiple workers not supported with sparse arrays")
@@ -81,7 +79,7 @@ def test_dataloader(
                 num_workers=num_workers,
                 buffer_bytes=buffer_bytes,
                 batch_size=batch_size,
-                shuffle=shuffle,
+                shuffle_buffer_size=shuffle_buffer_size,
                 **kwargs
             )
             assert isinstance(dataloader, torch.utils.data.DataLoader)
@@ -123,9 +121,9 @@ def test_unequal_num_rows(
         y_shape,
         num_attrs,
         pass_attrs,
-        batch_size,
         buffer_bytes,
-        shuffle,
+        batch_size,
+        shuffle_buffer_size,
     ):
         with ingest_in_tiledb(
             tmpdir,
@@ -142,12 +140,12 @@ def test_unequal_num_rows(
                     num_workers=num_workers,
                     buffer_bytes=buffer_bytes,
                     batch_size=batch_size,
-                    shuffle=shuffle,
+                    shuffle_buffer_size=shuffle_buffer_size,
                     **kwargs
                 )
             assert "X and Y arrays must have the same number of rows" in str(ex.value)
 
-    @parametrize_for_dataset(x_sparse=[True], shuffle=[False])
+    @parametrize_for_dataset(x_sparse=[True], shuffle_buffer_size=[0])
     def test_sparse_read_order(
         self,
         tmpdir,
@@ -158,9 +156,9 @@ def test_sparse_read_order(
         y_shape,
         num_attrs,
         pass_attrs,
-        batch_size,
         buffer_bytes,
-        shuffle,
+        batch_size,
+        shuffle_buffer_size,
     ):
         x_data = rand_array(num_rows, *x_shape, sparse=x_sparse)
         with ingest_in_tiledb(
@@ -175,7 +173,7 @@ def test_sparse_read_order(
             dataloader = PyTorchTileDBDataLoader(
                 buffer_bytes=buffer_bytes,
                 batch_size=batch_size,
-                shuffle=shuffle,
+                shuffle_buffer_size=shuffle_buffer_size,
                 **kwargs
             )
             generated_x_data = np.concatenate(

tests/readers/test_tensorflow.py

Lines changed: 10 additions & 11 deletions
@@ -38,9 +38,9 @@ def test_dataset(
         y_shape,
         num_attrs,
         pass_attrs,
-        batch_size,
         buffer_bytes,
-        shuffle,
+        batch_size,
+        shuffle_buffer_size,
     ):
         with ingest_in_tiledb(
             tmpdir,
@@ -54,7 +54,7 @@ def test_dataset(
             dataset = TensorflowTileDBDataset(
                 buffer_bytes=buffer_bytes,
                 batch_size=batch_size,
-                shuffle=shuffle,
+                shuffle_buffer_size=shuffle_buffer_size,
                 **kwargs,
             )
             assert isinstance(dataset, tf.data.Dataset)
@@ -67,7 +67,6 @@ def test_dataset(
             # covered so test it explicitly.
             generator = tensor_generator(
                 buffer_bytes=buffer_bytes,
-                shuffle=shuffle,
                 sparse_tensor_generator_cls=TensorflowSparseTileDBTensorGenerator,
                 **kwargs,
             )
@@ -87,9 +86,9 @@ def test_unequal_num_rows(
         y_shape,
         num_attrs,
         pass_attrs,
-        batch_size,
         buffer_bytes,
-        shuffle,
+        batch_size,
+        shuffle_buffer_size,
     ):
         with ingest_in_tiledb(
             tmpdir,
@@ -105,12 +104,12 @@ def test_unequal_num_rows(
                 TensorflowTileDBDataset(
                     buffer_bytes=buffer_bytes,
                     batch_size=batch_size,
-                    shuffle=shuffle,
+                    shuffle_buffer_size=shuffle_buffer_size,
                     **kwargs,
                 )
             assert "X and Y arrays must have the same number of rows" in str(ex.value)
 
-    @parametrize_for_dataset(x_sparse=[True], shuffle=[False])
+    @parametrize_for_dataset(x_sparse=[True], shuffle_buffer_size=[0])
     def test_sparse_read_order(
         self,
         tmpdir,
@@ -121,9 +120,9 @@ def test_sparse_read_order(
         y_shape,
         num_attrs,
         pass_attrs,
-        batch_size,
         buffer_bytes,
-        shuffle,
+        batch_size,
+        shuffle_buffer_size,
     ):
         x_data = rand_array(num_rows, *x_shape, sparse=x_sparse)
         with ingest_in_tiledb(
@@ -138,7 +137,7 @@ def test_sparse_read_order(
             dataset = TensorflowTileDBDataset(
                 buffer_bytes=buffer_bytes,
                 batch_size=batch_size,
-                shuffle=shuffle,
+                shuffle_buffer_size=shuffle_buffer_size,
                 **kwargs,
             )
             generated_x_data = np.concatenate(

tests/readers/utils.py

Lines changed: 5 additions & 5 deletions
@@ -20,7 +20,7 @@ def parametrize_for_dataset(
     pass_attrs=(True, False),
     batch_size=(8,),
     buffer_bytes=(1024, None),
-    shuffle=(True, False),
+    shuffle_buffer_size=(0, 16),
 ):
     def is_valid_combination(t):
         x_sparse_, y_sparse_, x_shape_, y_shape_, *_ = t
@@ -36,9 +36,9 @@ def is_valid_combination(t):
         "y_shape",
         "num_attrs",
         "pass_attrs",
-        "batch_size",
         "buffer_bytes",
-        "shuffle",
+        "batch_size",
+        "shuffle_buffer_size",
     ]
     argvalues = filter(
         is_valid_combination,
@@ -49,9 +49,9 @@ def is_valid_combination(t):
            y_shape,
            num_attrs,
            pass_attrs,
-           batch_size,
            buffer_bytes,
-           shuffle,
+           batch_size,
+           shuffle_buffer_size,
        ),
    )
    return pytest.mark.parametrize(argnames, argvalues)

tiledb/ml/readers/_batch_utils.py

Lines changed: 0 additions & 28 deletions
@@ -36,18 +36,6 @@ def read_buffer(self, array_slice: slice) -> None:
         :param array_slice: Requested array slice.
         """
 
-    @abstractmethod
-    def shuffle_buffer(self, buffer_slice: slice, row_idxs: np.ndarray) -> None:
-        """
-        Shuffle a slice of the current buffer.
-
-        Must be called after `read_buffer`.
-
-        :param buffer_slice: Slice of the current buffer to shuffle.
-        :param row_idxs: Shuffled indices; a shuffled version of
-            `np.arange(0, buffer_slice.stop - buffer_slice.start)`
-        """
-
     @abstractmethod
     def iter_tensors(self, buffer_slice: slice) -> Iterator[Tensor]:
         """
@@ -63,10 +51,6 @@ class TileDBNumpyGenerator(TileDBTensorGenerator[np.ndarray]):
     def read_buffer(self, array_slice: slice) -> None:
         self._buf_arrays = tuple(self._query[array_slice].values())
 
-    def shuffle_buffer(self, buffer_slice: slice, row_idxs: np.ndarray) -> None:
-        for buf_array in self._buf_arrays:
-            buf_array[buffer_slice] = buf_array[buffer_slice.start + row_idxs]
-
     def iter_tensors(self, buffer_slice: slice) -> Iterator[np.ndarray]:
         for buf_array in self._buf_arrays:
             yield buf_array[buffer_slice]
@@ -105,10 +89,6 @@ def read_buffer(self, array_slice: slice) -> None:
             for data in buffer.values()
         )
 
-    def shuffle_buffer(self, buffer_slice: slice, row_idxs: np.ndarray) -> None:
-        for buf_csr in self._buf_csrs:
-            buf_csr[buffer_slice] = buf_csr[buffer_slice.start + row_idxs]
-
     def iter_tensors(self, buffer_slice: slice) -> Iterator[Tensor]:
         for buf_csr, dtype in zip(self._buf_csrs, self._attr_dtypes):
             batch_csr = buf_csr[buffer_slice]
@@ -137,7 +117,6 @@ def tensor_generator(
     x_array: tiledb.Array,
     y_array: tiledb.Array,
     buffer_bytes: Optional[int] = None,
-    shuffle: bool = False,
     x_attrs: Sequence[str] = (),
     y_attrs: Sequence[str] = (),
     start_offset: int = 0,
@@ -159,7 +138,6 @@
     :param y_array: TileDB array of the labels.
    :param buffer_bytes: Maximum size (in bytes) of memory to allocate for reading from
        each array (default=`tiledb.default_ctx().config()["sm.memory_budget"]`).
-   :param shuffle: True for shuffling rows.
    :param x_attrs: Attribute names of x_array; defaults to all x_array attributes.
    :param y_attrs: Attribute names of y_array; defaults to all y_array attributes
    :param start_offset: Start row offset; defaults to 0.
@@ -205,12 +183,6 @@ def get_buffer_size_generator(
         elif batch.y_read_slice:
             y_gen.read_buffer(batch.y_read_slice)
 
-        if shuffle:
-            row_idxs = np.arange(batch.size)
-            np.random.shuffle(row_idxs)
-            x_gen.shuffle_buffer(batch.x_buffer_slice, row_idxs)
-            y_gen.shuffle_buffer(batch.y_buffer_slice, row_idxs)
-
         x_tensors = x_gen.iter_tensors(batch.x_buffer_slice)
         y_tensors = y_gen.iter_tensors(batch.y_buffer_slice)
         yield (*x_tensors, *y_tensors)
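
With the in-generator shuffling above removed, shuffling becomes the responsibility of the dataset layer via shuffle_buffer_size (per the commit message; the dataset modules themselves are not shown in this excerpt). In TensorFlow this maps naturally onto tf.data.Dataset.shuffle(buffer_size); for an iterable PyTorch dataset the same idea is a bounded shuffle buffer. A generic sketch of the technique, not the code added by this commit:

import random
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")

def iter_shuffled(rows: Iterable[T], buffer_size: int) -> Iterator[T]:
    # Keep up to buffer_size rows in memory and yield them in random order;
    # buffer_size=0 degenerates to a pass-through, matching the tests' use of
    # shuffle_buffer_size=[0] for order-sensitive sparse reads.
    buffer = []
    for row in rows:
        buffer.append(row)
        if len(buffer) >= buffer_size:
            # Swap a random buffered row to the end and yield it.
            idx = random.randrange(len(buffer))
            buffer[idx], buffer[-1] = buffer[-1], buffer[idx]
            yield buffer.pop()
    # Drain whatever is left at the end of the input.
    random.shuffle(buffer)
    yield from buffer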
