|
1 | 1 | """Tests for TileDB integration with Tensorflow Data API."""
|
2 | 2 |
|
3 |
| -import os |
4 |
| - |
5 | 3 | import numpy as np
|
6 | 4 | import pytest
|
7 | 5 | import tensorflow as tf
|
8 | 6 |
|
9 |
| -from tiledb.ml.readers._batch_utils import tensor_generator |
10 |
| -from tiledb.ml.readers.tensorflow import ( |
11 |
| - TensorflowSparseTileDBTensorGenerator, |
12 |
| - TensorflowTileDBDataset, |
13 |
| -) |
| 7 | +from tiledb.ml.readers.tensorflow import TensorflowTileDBDataset |
14 | 8 |
|
15 | 9 | from .utils import (
|
16 | 10 | ingest_in_tiledb,
|
|
19 | 13 | validate_tensor_generator,
|
20 | 14 | )
|
21 | 15 |
|
22 |
| -# Suppress all Tensorflow messages |
23 |
| -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" |
24 |
| - |
25 |
| -tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) |
26 |
| - |
27 | 16 |
|
28 | 17 | @pytest.mark.parametrize("num_rows", [107])
|
29 | 18 | class TestTensorflowTileDBDataset:
|
@@ -62,19 +51,6 @@ def test_dataset(
|
62 | 51 | dataset, num_attrs, x_sparse, y_sparse, x_shape, y_shape, batch_size
|
63 | 52 | )
|
64 | 53 |
|
65 |
| - # Although TensorflowTileDBDataset calls tensor_generator internally, due to |
66 |
| - # https://github.com/tensorflow/tensorflow/issues/33759 it is not reported as |
67 |
| - # covered so test it explicitly. |
68 |
| - generator = tensor_generator( |
69 |
| - buffer_bytes=buffer_bytes, |
70 |
| - sparse_tensor_generator_cls=TensorflowSparseTileDBTensorGenerator, |
71 |
| - **kwargs, |
72 |
| - ) |
73 |
| - # tensor_generator does not take batch_size parameter, so pass batch_size=num_rows |
74 |
| - validate_tensor_generator( |
75 |
| - generator, num_attrs, x_sparse, y_sparse, x_shape, y_shape, num_rows |
76 |
| - ) |
77 |
| - |
78 | 54 | @parametrize_for_dataset()
|
79 | 55 | def test_unequal_num_rows(
|
80 | 56 | self,
|
|
0 commit comments