|
6 | 6 | import tensorflow as tf
|
7 | 7 | import torch
|
8 | 8 |
|
| 9 | +from keras.src import Sequential |
9 | 10 | from keras.src import backend
|
| 11 | +from keras.src import layers |
10 | 12 | from keras.src import testing
|
11 | 13 | from keras.src.trainers.data_adapters import tf_dataset_adapter
|
12 | 14 |
|
@@ -286,3 +288,66 @@ def test_tf_sparse_tensors(self):
|
286 | 288 | self.assertIsInstance(by, expected_class)
|
287 | 289 | self.assertEqual(bx.shape, (2, 4))
|
288 | 290 | self.assertEqual(by.shape, (2, 2))
|
| 291 | + |
| 292 | + def test_distributed_datasets_from_function_adapter_properties(self): |
| 293 | + strategy = tf.distribute.MirroredStrategy() |
| 294 | + |
| 295 | + def dataset_fn(input_context): |
| 296 | + batch_size = input_context.get_per_replica_batch_size( |
| 297 | + global_batch_size=2 |
| 298 | + ) |
| 299 | + x = tf.random.uniform((32, 4)) |
| 300 | + y = tf.random.uniform((32, 2)) |
| 301 | + return tf.data.Dataset.from_tensor_slices((x, y)).batch(batch_size) |
| 302 | + |
| 303 | + dist_dataset = strategy.distribute_datasets_from_function(dataset_fn) |
| 304 | + adapter = tf_dataset_adapter.TFDatasetAdapter(dist_dataset) |
| 305 | + self.assertEqual(adapter.num_batches, 16) |
| 306 | + self.assertIsNone(adapter.batch_size) |
| 307 | + self.assertIsNone(adapter.has_partial_batch) |
| 308 | + self.assertIsNone(adapter.partial_batch_size) |
| 309 | + |
| 310 | + if backend.backend() == "numpy": |
| 311 | + it = adapter.get_numpy_iterator() |
| 312 | + expected_class = np.ndarray |
| 313 | + elif backend.backend() == "tensorflow": |
| 314 | + it = adapter.get_tf_dataset() |
| 315 | + expected_class = tf.Tensor |
| 316 | + elif backend.backend() == "jax": |
| 317 | + it = adapter.get_jax_iterator() |
| 318 | + expected_class = np.ndarray |
| 319 | + elif backend.backend() == "torch": |
| 320 | + it = adapter.get_torch_dataloader() |
| 321 | + expected_class = torch.Tensor |
| 322 | + |
| 323 | + batch_count = 0 |
| 324 | + for batch in it: |
| 325 | + batch_count += 1 |
| 326 | + self.assertEqual(len(batch), 2) |
| 327 | + data, labels = batch |
| 328 | + self.assertIsInstance(data, expected_class) |
| 329 | + self.assertIsInstance(labels, expected_class) |
| 330 | + self.assertEqual(data.shape, (2, 4)) |
| 331 | + self.assertEqual(labels.shape, (2, 2)) |
| 332 | + |
| 333 | + self.assertEqual(batch_count, 16) |
| 334 | + |
| 335 | + @pytest.mark.requires_trainable_backend |
| 336 | + def test_distributed_datasets_from_function_model_integration(self): |
| 337 | + strategy = tf.distribute.MirroredStrategy() |
| 338 | + |
| 339 | + def dataset_fn(input_context): |
| 340 | + batch_size = input_context.get_per_replica_batch_size( |
| 341 | + global_batch_size=2 |
| 342 | + ) |
| 343 | + x = tf.random.uniform((4, 1)) |
| 344 | + y = tf.random.uniform((4, 2)) |
| 345 | + return tf.data.Dataset.from_tensor_slices((x, y)).batch(batch_size) |
| 346 | + |
| 347 | + dist_dataset = strategy.distribute_datasets_from_function(dataset_fn) |
| 348 | + |
| 349 | + model = Sequential([layers.Dense(2, input_shape=(1,))]) |
| 350 | + model.compile(optimizer="adam", loss="mse") |
| 351 | + model.fit(dist_dataset, epochs=1) |
| 352 | + history = model.fit(dist_dataset, epochs=1) |
| 353 | + self.assertIn("loss", history.history) |
0 commit comments