Skip to content

Commit e38a75b

Browse files
committed
Inline get_buffer_size_generator and rename several identifiers
1 parent ecb2c29 commit e38a75b

File tree

3 files changed

+24
-31
lines changed

3 files changed

+24
-31
lines changed

Diff for: tiledb/ml/readers/_tensor_gen.py

+18-23
Original file line number | Diff line number | Diff line change
@@ -49,7 +49,7 @@ def iter_tensors(self, buffer_slice: slice) -> Iterator[np.ndarray]:
4949
yield buf_array[buffer_slice]
5050

5151

52-
class SparseTileDBTensorGenerator(TileDBTensorGenerator[Tensor]):
52+
class TileDBSparseTensorGenerator(TileDBTensorGenerator[Tensor]):
5353
def __init__(self, array: tiledb.Array, attrs: Sequence[str]) -> None:
5454
schema = array.schema
5555
if schema.ndim != 2:
@@ -96,8 +96,8 @@ def _tensor_from_coo(
9696
"""Convert a scipy.sparse.coo_matrix to a Tensor"""
9797

9898

99-
DenseTensor = TypeVar("DenseTensor")
100-
SparseTensor = TypeVar("SparseTensor")
99+
DT = TypeVar("DT")
100+
ST = TypeVar("ST")
101101

102102

103103
def tensor_generator(
@@ -109,13 +109,9 @@ def tensor_generator(
109109
y_attrs: Sequence[str],
110110
start_offset: int = 0,
111111
stop_offset: int = 0,
112-
dense_tensor_generator_cls: Type[
113-
TileDBTensorGenerator[DenseTensor]
114-
] = TileDBNumpyGenerator,
115-
sparse_tensor_generator_cls: Type[
116-
TileDBTensorGenerator[SparseTensor]
117-
] = SparseTileDBTensorGenerator,
118-
) -> Iterator[Sequence[Union[DenseTensor, SparseTensor]]]:
112+
dense_generator_cls: Type[TileDBTensorGenerator[DT]] = TileDBNumpyGenerator,
113+
sparse_generator_cls: Type[TileDBTensorGenerator[ST]] = TileDBSparseTensorGenerator,
114+
) -> Iterator[Sequence[Union[DT, ST]]]:
119115
"""
120116
Generator for batches of tensors.
121117
@@ -130,20 +126,19 @@ def tensor_generator(
130126
:param y_attrs: Attribute names of y_array.
131127
:param start_offset: Start row offset; defaults to 0.
132128
:param stop_offset: Stop row offset; defaults to number of rows.
133-
:param dense_tensor_generator_cls: Dense tensor generator type.
134-
:param sparse_tensor_generator_cls: Sparse tensor generator type.
129+
:param dense_generator_cls: Dense tensor generator type.
130+
:param sparse_generator_cls: Sparse tensor generator type.
135131
"""
136-
137-
def get_buffer_size_generator(
138-
array: tiledb.Array, attrs: Sequence[str]
139-
) -> Union[TileDBTensorGenerator[DenseTensor], TileDBTensorGenerator[SparseTensor]]:
140-
if array.schema.sparse:
141-
return sparse_tensor_generator_cls(array, attrs)
142-
else:
143-
return dense_tensor_generator_cls(array, attrs)
144-
145-
x_gen = get_buffer_size_generator(x_array, x_attrs)
146-
y_gen = get_buffer_size_generator(y_array, y_attrs)
132+
x_gen: Union[TileDBTensorGenerator[DT], TileDBTensorGenerator[ST]] = (
133+
sparse_generator_cls(x_array, x_attrs)
134+
if x_array.schema.sparse
135+
else dense_generator_cls(x_array, x_attrs)
136+
)
137+
y_gen: Union[TileDBTensorGenerator[DT], TileDBTensorGenerator[ST]] = (
138+
sparse_generator_cls(y_array, y_attrs)
139+
if y_array.schema.sparse
140+
else dense_generator_cls(y_array, y_attrs)
141+
)
147142
if not stop_offset:
148143
stop_offset = x_array.shape[0]
149144
for batch in iter_batches(x_buffer_size, y_buffer_size, start_offset, stop_offset):

Diff for: tiledb/ml/readers/pytorch.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import tiledb
1212

1313
from ._buffer_utils import get_attr_names, get_buffer_size
14-
from ._tensor_gen import SparseTileDBTensorGenerator, tensor_generator
14+
from ._tensor_gen import TileDBSparseTensorGenerator, tensor_generator
1515

1616

1717
def PyTorchTileDBDataLoader(
@@ -92,7 +92,7 @@ def __init__(
9292
y_buffer_size=get_buffer_size(y_array, y_attrs, buffer_bytes),
9393
x_attrs=x_attrs,
9494
y_attrs=y_attrs,
95-
sparse_tensor_generator_cls=PyTorchSparseTileDBTensorGenerator,
95+
sparse_generator_cls=PyTorchSparseTensorGenerator,
9696
)
9797

9898
def __iter__(self) -> Iterator[Sequence[torch.Tensor]]:
@@ -166,7 +166,7 @@ def iter_shuffled(iterable: Iterable[T], buffer_size: int) -> Iterator[T]:
166166
yield buffer.pop()
167167

168168

169-
class PyTorchSparseTileDBTensorGenerator(SparseTileDBTensorGenerator[torch.Tensor]):
169+
class PyTorchSparseTensorGenerator(TileDBSparseTensorGenerator[torch.Tensor]):
170170
@staticmethod
171171
def _tensor_from_coo(
172172
data: np.ndarray,

Diff for: tiledb/ml/readers/tensorflow.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import tiledb
1010

1111
from ._buffer_utils import get_attr_names, get_buffer_size
12-
from ._tensor_gen import SparseTileDBTensorGenerator, tensor_generator
12+
from ._tensor_gen import TileDBSparseTensorGenerator, tensor_generator
1313

1414
# TODO: We have to track the following issues:
1515
# - https://github.com/tensorflow/tensorflow/issues/47532
@@ -56,7 +56,7 @@ def TensorflowTileDBDataset(
5656
y_buffer_size=get_buffer_size(y_array, y_attrs, buffer_bytes),
5757
x_attrs=x_attrs,
5858
y_attrs=y_attrs,
59-
sparse_tensor_generator_cls=TensorflowSparseTileDBTensorGenerator,
59+
sparse_generator_cls=TensorflowSparseTensorGenerator,
6060
),
6161
output_signature=(
6262
*_iter_tensor_specs(x_array.schema, x_attrs),
@@ -77,9 +77,7 @@ def _iter_tensor_specs(
7777
yield cls(shape=(None, *schema.shape[1:]), dtype=schema.attr(attr).dtype)
7878

7979

80-
class TensorflowSparseTileDBTensorGenerator(
81-
SparseTileDBTensorGenerator[tf.SparseTensor]
82-
):
80+
class TensorflowSparseTensorGenerator(TileDBSparseTensorGenerator[tf.SparseTensor]):
8381
@staticmethod
8482
def _tensor_from_coo(
8583
data: np.ndarray,

0 commit comments

Comments
 (0)