Skip to content

Commit f6e09b6

Browse files
Pass Through Arguments Tensorflow (#170)
* removed batch size, shuffle and prefetch from Tensorflow reader * tensorflow notebook update * PR changes
1 parent 5943ced commit f6e09b6

File tree

4 files changed

+43
-51
lines changed

4 files changed

+43
-51
lines changed

examples/readers/tensorflow_data_api_tiledb_dense.ipynb

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
"execution_count": 2,
3939
"metadata": {
4040
"pycharm": {
41-
"is_executing": true,
4241
"name": "#%%\n"
4342
}
4443
},
@@ -78,7 +77,6 @@
7877
"execution_count": 3,
7978
"metadata": {
8079
"pycharm": {
81-
"is_executing": true,
8280
"name": "#%%\n"
8381
}
8482
},
@@ -126,7 +124,6 @@
126124
"execution_count": 4,
127125
"metadata": {
128126
"pycharm": {
129-
"is_executing": true,
130127
"name": "#%%\n"
131128
}
132129
},
@@ -158,7 +155,6 @@
158155
"execution_count": 5,
159156
"metadata": {
160157
"pycharm": {
161-
"is_executing": true,
162158
"name": "#%%\n"
163159
}
164160
},
@@ -220,15 +216,14 @@
220216
"execution_count": 6,
221217
"metadata": {
222218
"pycharm": {
223-
"is_executing": true,
224219
"name": "#%%\n"
225220
}
226221
},
227222
"outputs": [
228223
{
229224
"data": {
230225
"text/plain": [
231-
"<matplotlib.image.AxesImage at 0x15cde5e80>"
226+
"<matplotlib.image.AxesImage at 0x157c2e8e0>"
232227
]
233228
},
234229
"execution_count": 6,
@@ -270,7 +265,6 @@
270265
"execution_count": 7,
271266
"metadata": {
272267
"pycharm": {
273-
"is_executing": true,
274268
"name": "#%%\n"
275269
}
276270
},
@@ -309,7 +303,6 @@
309303
"execution_count": 8,
310304
"metadata": {
311305
"pycharm": {
312-
"is_executing": true,
313306
"name": "#%%\n"
314307
}
315308
},
@@ -325,23 +318,23 @@
325318
"name": "stderr",
326319
"output_type": "stream",
327320
"text": [
328-
"2022-07-05 18:13:38.031008: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
321+
"2022-07-26 13:24:54.724292: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
329322
"To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
330323
]
331324
},
332325
{
333326
"name": "stdout",
334327
"output_type": "stream",
335328
"text": [
336-
"938/938 [==============================] - 4s 4ms/step - loss: 0.3456 - accuracy: 0.9013\n",
329+
"938/938 [==============================] - 5s 4ms/step - loss: 0.3414 - accuracy: 0.9013\n",
337330
"Epoch 2/5\n",
338-
"938/938 [==============================] - 3s 3ms/step - loss: 0.1683 - accuracy: 0.9510\n",
331+
"938/938 [==============================] - 4s 4ms/step - loss: 0.1662 - accuracy: 0.9515\n",
339332
"Epoch 3/5\n",
340-
"938/938 [==============================] - 3s 3ms/step - loss: 0.1254 - accuracy: 0.9631\n",
333+
"938/938 [==============================] - 4s 4ms/step - loss: 0.1252 - accuracy: 0.9631\n",
341334
"Epoch 4/5\n",
342-
"938/938 [==============================] - 3s 3ms/step - loss: 0.1037 - accuracy: 0.9687\n",
335+
"938/938 [==============================] - 4s 4ms/step - loss: 0.1021 - accuracy: 0.9694\n",
343336
"Epoch 5/5\n",
344-
"938/938 [==============================] - 3s 3ms/step - loss: 0.0873 - accuracy: 0.9730\n"
337+
"938/938 [==============================] - 4s 4ms/step - loss: 0.0878 - accuracy: 0.9735\n"
345338
]
346339
}
347340
],
@@ -355,8 +348,9 @@
355348
" tiledb_dataset = TensorflowTileDBDataset(\n",
356349
" ArrayParams(array=x, fields=['features']),\n",
357350
" ArrayParams(array=y, fields=['features']),\n",
358-
" batch_size=64, shuffle_buffer_size=128\n",
359-
" )\n",
351+
" num_workers=2 \n",
352+
" ) \n",
353+
" tiledb_dataset = tiledb_dataset.batch(64).shuffle(128)\n",
360354
" model.fit(tiledb_dataset, epochs=5)"
361355
]
362356
},
@@ -365,7 +359,6 @@
365359
"execution_count": 9,
366360
"metadata": {
367361
"pycharm": {
368-
"is_executing": true,
369362
"name": "#%%\n"
370363
}
371364
},
@@ -397,13 +390,6 @@
397390
"source": [
398391
"model.summary()"
399392
]
400-
},
401-
{
402-
"cell_type": "code",
403-
"execution_count": null,
404-
"metadata": {},
405-
"outputs": [],
406-
"source": []
407393
}
408394
],
409395
"metadata": {

examples/readers/tensorflow_data_api_tiledb_sparse.ipynb

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -669,39 +669,45 @@
669669
}
670670
},
671671
"outputs": [
672+
{
673+
"name": "stdout",
674+
"output_type": "stream",
675+
"text": [
676+
"Epoch 1/2\n"
677+
]
678+
},
672679
{
673680
"name": "stderr",
674681
"output_type": "stream",
675682
"text": [
676-
"2022-06-22 17:31:04.946401: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
677-
"To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
678-
"2022-06-22 17:31:05.168238: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)\n",
679-
"/Users/konstantinostsitsimpikos/tileroot/TileDB-ML/venv2/lib/python3.9/site-packages/tensorflow/python/framework/indexed_slices.py:447: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor(\"gradient_tape/sequential/dense/embedding_lookup_sparse/Reshape_1:0\", shape=(None,), dtype=int32), values=Tensor(\"gradient_tape/sequential/dense/embedding_lookup_sparse/Reshape:0\", shape=(None, 1), dtype=float32), dense_shape=Tensor(\"gradient_tape/sequential/dense/embedding_lookup_sparse/Cast:0\", shape=(2,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory.\n",
680-
" warnings.warn(\n"
683+
"2022-07-26 13:28:43.457755: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
684+
"To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
681685
]
682686
},
683687
{
684688
"name": "stdout",
685689
"output_type": "stream",
686690
"text": [
687-
"Epoch 1/2\n",
688-
"3125/3125 [==============================] - 2s 485us/step - loss: 0.0000e+00 - accuracy: 0.0607\n",
691+
"3125/3125 [==============================] - 3s 832us/step - loss: 0.0000e+00 - accuracy: 0.0607\n",
689692
"Epoch 2/2\n",
690-
"3125/3125 [==============================] - 2s 464us/step - loss: 0.0000e+00 - accuracy: 0.0611\n"
693+
"3125/3125 [==============================] - 3s 697us/step - loss: 0.0000e+00 - accuracy: 0.0611\n"
691694
]
692695
}
693696
],
694697
"source": [
695698
"from tiledb.ml.readers.tensorflow import TensorflowTileDBDataset, ArrayParams\n",
696699
"\n",
700+
"import warnings\n",
701+
"warnings.filterwarnings(\"ignore\")\n",
702+
"\n",
697703
"ctx = tiledb.Ctx({\"sm.memory_budget\": 1024**2, \"py.init_buffer_bytes\": 1024**2})\n",
698704
"with tiledb.open(training_images, ctx=ctx) as x, tiledb.open(training_labels, ctx=ctx) as y:\n",
699705
" tiledb_dataset = TensorflowTileDBDataset(\n",
700706
" ArrayParams(array=x, fields=['features']),\n",
701-
" ArrayParams(array=y, fields=['features']),\n",
702-
" batch_size=32)\n",
707+
" ArrayParams(array=y, fields=['features']))\n",
703708
" model = design_model(input_shape=user_movie.shape[1])\n",
704-
" model.fit(tiledb_dataset, epochs=2, batch_size=32)"
709+
" tiledb_dataset = tiledb_dataset.batch(32)\n",
710+
" model.fit(tiledb_dataset, epochs=2)"
705711
]
706712
}
707713
],
@@ -726,4 +732,4 @@
726732
},
727733
"nbformat": 4,
728734
"nbformat_minor": 4
729-
}
735+
}

tests/readers/test_tensorflow.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@
99
from .utils import ingest_in_tiledb, parametrize_for_dataset, validate_tensor_generator
1010

1111

12+
def dataset_batching_shuffling(dataset: tf.data.Dataset, batch_size: int, shuffle_buffer_size: int) -> tf.data.Dataset:
13+
if shuffle_buffer_size > 0:
14+
dataset = dataset.shuffle(shuffle_buffer_size)
15+
return dataset.batch(batch_size)
16+
17+
1218
class TestTensorflowTileDBDataset:
1319
@parametrize_for_dataset()
1420
def test_dataset(
@@ -19,9 +25,12 @@ def test_dataset(
1925
dataset = TensorflowTileDBDataset(
2026
x_params,
2127
y_params,
28+
num_workers=num_workers,
29+
)
30+
dataset = dataset_batching_shuffling(
31+
dataset=dataset,
2232
batch_size=batch_size,
2333
shuffle_buffer_size=shuffle_buffer_size,
24-
num_workers=num_workers,
2534
)
2635
assert isinstance(dataset, tf.data.Dataset)
2736
validate_tensor_generator(
@@ -42,8 +51,6 @@ def test_unequal_num_keys(
4251
TensorflowTileDBDataset(
4352
x_params,
4453
y_params,
45-
batch_size=batch_size,
46-
shuffle_buffer_size=shuffle_buffer_size,
4754
num_workers=num_workers,
4855
)
4956
assert "All arrays must have the same key range" in str(ex.value)
@@ -62,9 +69,12 @@ def test_dataset_order(
6269
dataset = TensorflowTileDBDataset(
6370
x_params,
6471
y_params,
72+
num_workers=num_workers,
73+
)
74+
dataset = dataset_batching_shuffling(
75+
dataset=dataset,
6576
batch_size=batch_size,
6677
shuffle_buffer_size=shuffle_buffer_size,
67-
num_workers=num_workers,
6878
)
6979
# since num_fields is 0, fields are all the array attributes of each array
7080
# the first item of each batch corresponds to the first attribute (="data")

tiledb/ml/readers/tensorflow.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,10 @@
1313

1414
def TensorflowTileDBDataset(
1515
*all_array_params: ArrayParams,
16-
batch_size: int,
17-
shuffle_buffer_size: int = 0,
18-
prefetch: int = tf.data.AUTOTUNE,
1916
num_workers: int = 0,
2017
) -> tf.data.Dataset:
2118
"""Return a tf.data.Dataset for loading data from TileDB arrays.
22-
2319
:param all_array_params: One or more `ArrayParams` instances, one per TileDB array.
24-
:param batch_size: Size of each batch.
25-
:param shuffle_buffer_size: Number of elements from which this dataset will sample.
26-
:param prefetch: Maximum number of batches that will be buffered when prefetching.
27-
By default, the buffer size is dynamically tuned.
2820
:param num_workers: If greater than zero, create a threadpool of `num_workers` threads
2921
used to fetch inputs asynchronously and in parallel. Note: when `num_workers` > 1
3022
yielded batches may be shuffled even if `shuffle_buffer_size` is zero.
@@ -60,9 +52,7 @@ def key_range_dataset(key_range_idx: int) -> tf.data.Dataset:
6052
else:
6153
dataset = key_range_dataset(0)
6254

63-
if shuffle_buffer_size > 0:
64-
dataset = dataset.shuffle(shuffle_buffer_size)
65-
return dataset.batch(batch_size).prefetch(prefetch)
55+
return dataset
6656

6757

6858
_tensor_specs = {

0 commit comments

Comments
 (0)