
Commit fc1b26d

fix(trainers): Add support for DistributedDatasetsFromFunction in data adapters (#20829)

The is_tf_dataset() function in the data adapters now recognizes DistributedDatasetsFromFunction as a valid TensorFlow dataset type, so distributed datasets created via strategy.distribute_datasets_from_function() are handled correctly. A test case was added to verify support for distributed datasets built this way.
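A minimal usage sketch of what this change enables (TensorFlow backend assumed; the model, shapes, and optimizer are illustrative and mirror the added integration test, not a prescribed setup):

```python
import tensorflow as tf
import keras

strategy = tf.distribute.MirroredStrategy()

def dataset_fn(input_context):
    # Split the global batch across replicas.
    batch_size = input_context.get_per_replica_batch_size(global_batch_size=2)
    x = tf.random.uniform((8, 1))
    y = tf.random.uniform((8, 2))
    return tf.data.Dataset.from_tensor_slices((x, y)).batch(batch_size)

# This produces a DistributedDatasetsFromFunction object.
dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)

model = keras.Sequential([keras.layers.Dense(2)])
model.compile(optimizer="adam", loss="mse")
# With this fix, the data adapters recognize the distributed dataset and
# route it through TFDatasetAdapter, so fit() accepts it directly.
model.fit(dist_dataset, epochs=1)
```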
1 parent 295fdcc commit fc1b26d

File tree

2 files changed: +66 −0 lines


keras/src/trainers/data_adapters/__init__.py (+1 line)

@@ -139,6 +139,7 @@ def is_tf_dataset(x):
             if parent.__name__ in (
                 "DatasetV2",
                 "DistributedDataset",
+                "DistributedDatasetsFromFunction",
             ) and "tensorflow.python." in str(parent.__module__):
                 return True
     return False
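For context, the check behaves roughly like the sketch below (an illustrative approximation of the helper in keras/src/trainers/data_adapters/__init__.py, not the verbatim implementation): it scans the class hierarchy of the input, and the new entry lets datasets produced by strategy.distribute_datasets_from_function() match.

```python
def is_tf_dataset(x):
    # Walk the class hierarchy of x and accept any TensorFlow dataset-like
    # parent class, now including DistributedDatasetsFromFunction.
    for parent in type(x).__mro__:
        if parent.__name__ in (
            "DatasetV2",
            "DistributedDataset",
            "DistributedDatasetsFromFunction",
        ) and "tensorflow.python." in str(parent.__module__):
            return True
    return False
```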

keras/src/trainers/data_adapters/tf_dataset_adapter_test.py (+65 lines)

@@ -6,7 +6,9 @@
 import tensorflow as tf
 import torch

+from keras.src import Sequential
 from keras.src import backend
+from keras.src import layers
 from keras.src import testing
 from keras.src.trainers.data_adapters import tf_dataset_adapter

@@ -286,3 +288,66 @@ def test_tf_sparse_tensors(self):
         self.assertIsInstance(by, expected_class)
         self.assertEqual(bx.shape, (2, 4))
         self.assertEqual(by.shape, (2, 2))
+
+    def test_distributed_datasets_from_function_adapter_properties(self):
+        strategy = tf.distribute.MirroredStrategy()
+
+        def dataset_fn(input_context):
+            batch_size = input_context.get_per_replica_batch_size(
+                global_batch_size=2
+            )
+            x = tf.random.uniform((32, 4))
+            y = tf.random.uniform((32, 2))
+            return tf.data.Dataset.from_tensor_slices((x, y)).batch(batch_size)
+
+        dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)
+        adapter = tf_dataset_adapter.TFDatasetAdapter(dist_dataset)
+        self.assertEqual(adapter.num_batches, 16)
+        self.assertIsNone(adapter.batch_size)
+        self.assertIsNone(adapter.has_partial_batch)
+        self.assertIsNone(adapter.partial_batch_size)
+
+        if backend.backend() == "numpy":
+            it = adapter.get_numpy_iterator()
+            expected_class = np.ndarray
+        elif backend.backend() == "tensorflow":
+            it = adapter.get_tf_dataset()
+            expected_class = tf.Tensor
+        elif backend.backend() == "jax":
+            it = adapter.get_jax_iterator()
+            expected_class = np.ndarray
+        elif backend.backend() == "torch":
+            it = adapter.get_torch_dataloader()
+            expected_class = torch.Tensor
+
+        batch_count = 0
+        for batch in it:
+            batch_count += 1
+            self.assertEqual(len(batch), 2)
+            data, labels = batch
+            self.assertIsInstance(data, expected_class)
+            self.assertIsInstance(labels, expected_class)
+            self.assertEqual(data.shape, (2, 4))
+            self.assertEqual(labels.shape, (2, 2))
+
+        self.assertEqual(batch_count, 16)
+
+    @pytest.mark.requires_trainable_backend
+    def test_distributed_datasets_from_function_model_integration(self):
+        strategy = tf.distribute.MirroredStrategy()
+
+        def dataset_fn(input_context):
+            batch_size = input_context.get_per_replica_batch_size(
+                global_batch_size=2
+            )
+            x = tf.random.uniform((4, 1))
+            y = tf.random.uniform((4, 2))
+            return tf.data.Dataset.from_tensor_slices((x, y)).batch(batch_size)
+
+        dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)
+
+        model = Sequential([layers.Dense(2, input_shape=(1,))])
+        model.compile(optimizer="adam", loss="mse")
+        model.fit(dist_dataset, epochs=1)
+        history = model.fit(dist_dataset, epochs=1)
+        self.assertIn("loss", history.history)
