
Commit ecb2c29

Split tensor_generator from iter_batches into separate modules
1 parent: 399fc4c

4 files changed, +159 -156 lines

tiledb/ml/readers/_batch_utils.py  (+1 -154)

@@ -1,158 +1,5 @@
-from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Generic, Iterator, Optional, Sequence, Type, TypeVar, Union
-
-import numpy as np
-import scipy.sparse as sp
-
-import tiledb
-
-Tensor = TypeVar("Tensor")
-
-
-class TileDBTensorGenerator(ABC, Generic[Tensor]):
-    """Base class for generating tensors read from a TileDB array."""
-
-    def __init__(self, array: tiledb.Array, attrs: Sequence[str]) -> None:
-        """
-        :param array: TileDB array to read from.
-        :param attrs: Attribute names of array to read.
-        """
-        self._query = array.query(attrs=attrs)
-
-    @abstractmethod
-    def read_buffer(self, array_slice: slice) -> None:
-        """
-        Read an array slice and save it as the current buffer.
-
-        :param array_slice: Requested array slice.
-        """
-
-    @abstractmethod
-    def iter_tensors(self, buffer_slice: slice) -> Iterator[Tensor]:
-        """
-        Return an iterator of tensors for the given slice, one tensor per attribute
-
-        Must be called after `read_buffer`.
-
-        :param buffer_slice: Slice of the current buffer to convert to tensors.
-        """
-
-
-class TileDBNumpyGenerator(TileDBTensorGenerator[np.ndarray]):
-    def read_buffer(self, array_slice: slice) -> None:
-        self._buf_arrays = tuple(self._query[array_slice].values())
-
-    def iter_tensors(self, buffer_slice: slice) -> Iterator[np.ndarray]:
-        for buf_array in self._buf_arrays:
-            yield buf_array[buffer_slice]
-
-
-class SparseTileDBTensorGenerator(TileDBTensorGenerator[Tensor]):
-    def __init__(self, array: tiledb.Array, attrs: Sequence[str]) -> None:
-        schema = array.schema
-        if schema.ndim != 2:
-            raise NotImplementedError("Only 2D sparse tensors are currently supported")
-        self._row_dim = schema.domain.dim(0).name
-        self._col_dim = schema.domain.dim(1).name
-        self._row_shape = schema.shape[1:]
-        self._attr_dtypes = tuple(schema.attr(attr).dtype for attr in attrs)
-        super().__init__(array, attrs)
-
-    def read_buffer(self, array_slice: slice) -> None:
-        buffer = self._query[array_slice]
-        # COO to CSR transformation for batching and row slicing
-        row = buffer.pop(self._row_dim)
-        col = buffer.pop(self._col_dim)
-        # Normalize indices: We want the coords indices to be in the [0, array_slice size]
-        # range. If we do not normalize the sparse tensor is being created but with a
-        # dimension [0, max(coord_index)], which is overkill
-        start_offset = array_slice.start
-        stop_offset = array_slice.stop
-        shape = (stop_offset - start_offset, *self._row_shape)
-        self._buf_csrs = tuple(
-            sp.csr_matrix((data, (row - start_offset, col)), shape=shape)
-            for data in buffer.values()
-        )
-
-    def iter_tensors(self, buffer_slice: slice) -> Iterator[Tensor]:
-        for buf_csr, dtype in zip(self._buf_csrs, self._attr_dtypes):
-            batch_csr = buf_csr[buffer_slice]
-            batch_coo = batch_csr.tocoo()
-            data = batch_coo.data
-            coords = np.stack((batch_coo.row, batch_coo.col), axis=-1)
-            dense_shape = (batch_csr.shape[0], *self._row_shape)
-            yield self._tensor_from_coo(data, coords, dense_shape, dtype)
-
-    @staticmethod
-    @abstractmethod
-    def _tensor_from_coo(
-        data: np.ndarray,
-        coords: np.ndarray,
-        dense_shape: Sequence[int],
-        dtype: np.dtype,
-    ) -> Tensor:
-        """Convert a scipy.sparse.coo_matrix to a Tensor"""
-
-
-DenseTensor = TypeVar("DenseTensor")
-SparseTensor = TypeVar("SparseTensor")
-
-
-def tensor_generator(
-    x_array: tiledb.Array,
-    y_array: tiledb.Array,
-    x_buffer_size: int,
-    y_buffer_size: int,
-    x_attrs: Sequence[str],
-    y_attrs: Sequence[str],
-    start_offset: int = 0,
-    stop_offset: int = 0,
-    dense_tensor_generator_cls: Type[
-        TileDBTensorGenerator[DenseTensor]
-    ] = TileDBNumpyGenerator,
-    sparse_tensor_generator_cls: Type[
-        TileDBTensorGenerator[SparseTensor]
-    ] = SparseTileDBTensorGenerator,
-) -> Iterator[Sequence[Union[DenseTensor, SparseTensor]]]:
-    """
-    Generator for batches of tensors.
-
-    Each yielded batch is a sequence of N tensors of x_array followed by M tensors
-    of y_array, where `N == len(x_attrs)` and `M == len(y_attrs)`.
-
-    :param x_array: TileDB array of the features.
-    :param y_array: TileDB array of the labels.
-    :param x_buffer_size: Number of rows to read at a time from x_array.
-    :param y_buffer_size: Number of rows to read at a time from y_array.
-    :param x_attrs: Attribute names of x_array.
-    :param y_attrs: Attribute names of y_array.
-    :param start_offset: Start row offset; defaults to 0.
-    :param stop_offset: Stop row offset; defaults to number of rows.
-    :param dense_tensor_generator_cls: Dense tensor generator type.
-    :param sparse_tensor_generator_cls: Sparse tensor generator type.
-    """
-
-    def get_buffer_size_generator(
-        array: tiledb.Array, attrs: Sequence[str]
-    ) -> Union[TileDBTensorGenerator[DenseTensor], TileDBTensorGenerator[SparseTensor]]:
-        if array.schema.sparse:
-            return sparse_tensor_generator_cls(array, attrs)
-        else:
-            return dense_tensor_generator_cls(array, attrs)
-
-    x_gen = get_buffer_size_generator(x_array, x_attrs)
-    y_gen = get_buffer_size_generator(y_array, y_attrs)
-    if not stop_offset:
-        stop_offset = x_array.shape[0]
-    for batch in iter_batches(x_buffer_size, y_buffer_size, start_offset, stop_offset):
-        if batch.x_read_slice:
-            x_gen.read_buffer(batch.x_read_slice)
-        if batch.y_read_slice:
-            y_gen.read_buffer(batch.y_read_slice)
-        x_tensors = x_gen.iter_tensors(batch.x_buffer_slice)
-        y_tensors = y_gen.iter_tensors(batch.y_buffer_slice)
-        yield (*x_tensors, *y_tensors)
+from typing import Iterator, Optional
 
 
 @dataclass(frozen=True, repr=False)

tiledb/ml/readers/_tensor_gen.py  (new file, +156)

@@ -0,0 +1,156 @@
+from abc import ABC, abstractmethod
+from typing import Generic, Iterator, Sequence, Type, TypeVar, Union
+
+import numpy as np
+import scipy.sparse as sp
+
+import tiledb
+
+from ._batch_utils import iter_batches
+
+Tensor = TypeVar("Tensor")
+
+
+class TileDBTensorGenerator(ABC, Generic[Tensor]):
+    """Base class for generating tensors read from a TileDB array."""
+
+    def __init__(self, array: tiledb.Array, attrs: Sequence[str]) -> None:
+        """
+        :param array: TileDB array to read from.
+        :param attrs: Attribute names of array to read.
+        """
+        self._query = array.query(attrs=attrs)
+
+    @abstractmethod
+    def read_buffer(self, array_slice: slice) -> None:
+        """
+        Read an array slice and save it as the current buffer.
+
+        :param array_slice: Requested array slice.
+        """
+
+    @abstractmethod
+    def iter_tensors(self, buffer_slice: slice) -> Iterator[Tensor]:
+        """
+        Return an iterator of tensors for the given slice, one tensor per attribute
+
+        Must be called after `read_buffer`.
+
+        :param buffer_slice: Slice of the current buffer to convert to tensors.
+        """
+
+
+class TileDBNumpyGenerator(TileDBTensorGenerator[np.ndarray]):
+    def read_buffer(self, array_slice: slice) -> None:
+        self._buf_arrays = tuple(self._query[array_slice].values())
+
+    def iter_tensors(self, buffer_slice: slice) -> Iterator[np.ndarray]:
+        for buf_array in self._buf_arrays:
+            yield buf_array[buffer_slice]
+
+
+class SparseTileDBTensorGenerator(TileDBTensorGenerator[Tensor]):
+    def __init__(self, array: tiledb.Array, attrs: Sequence[str]) -> None:
+        schema = array.schema
+        if schema.ndim != 2:
+            raise NotImplementedError("Only 2D sparse tensors are currently supported")
+        self._row_dim = schema.domain.dim(0).name
+        self._col_dim = schema.domain.dim(1).name
+        self._row_shape = schema.shape[1:]
+        self._attr_dtypes = tuple(schema.attr(attr).dtype for attr in attrs)
+        super().__init__(array, attrs)
+
+    def read_buffer(self, array_slice: slice) -> None:
+        buffer = self._query[array_slice]
+        # COO to CSR transformation for batching and row slicing
+        row = buffer.pop(self._row_dim)
+        col = buffer.pop(self._col_dim)
+        # Normalize indices: We want the coords indices to be in the [0, array_slice size]
+        # range. If we do not normalize the sparse tensor is being created but with a
+        # dimension [0, max(coord_index)], which is overkill
+        start_offset = array_slice.start
+        stop_offset = array_slice.stop
+        shape = (stop_offset - start_offset, *self._row_shape)
+        self._buf_csrs = tuple(
+            sp.csr_matrix((data, (row - start_offset, col)), shape=shape)
+            for data in buffer.values()
+        )
+
+    def iter_tensors(self, buffer_slice: slice) -> Iterator[Tensor]:
+        for buf_csr, dtype in zip(self._buf_csrs, self._attr_dtypes):
+            batch_csr = buf_csr[buffer_slice]
+            batch_coo = batch_csr.tocoo()
+            data = batch_coo.data
+            coords = np.stack((batch_coo.row, batch_coo.col), axis=-1)
+            dense_shape = (batch_csr.shape[0], *self._row_shape)
+            yield self._tensor_from_coo(data, coords, dense_shape, dtype)
+
+    @staticmethod
+    @abstractmethod
+    def _tensor_from_coo(
+        data: np.ndarray,
+        coords: np.ndarray,
+        dense_shape: Sequence[int],
+        dtype: np.dtype,
+    ) -> Tensor:
+        """Convert a scipy.sparse.coo_matrix to a Tensor"""
+
+
+DenseTensor = TypeVar("DenseTensor")
+SparseTensor = TypeVar("SparseTensor")
+
+
+def tensor_generator(
+    x_array: tiledb.Array,
+    y_array: tiledb.Array,
+    x_buffer_size: int,
+    y_buffer_size: int,
+    x_attrs: Sequence[str],
+    y_attrs: Sequence[str],
+    start_offset: int = 0,
+    stop_offset: int = 0,
+    dense_tensor_generator_cls: Type[
+        TileDBTensorGenerator[DenseTensor]
+    ] = TileDBNumpyGenerator,
+    sparse_tensor_generator_cls: Type[
+        TileDBTensorGenerator[SparseTensor]
+    ] = SparseTileDBTensorGenerator,
+) -> Iterator[Sequence[Union[DenseTensor, SparseTensor]]]:
+    """
+    Generator for batches of tensors.
+
+    Each yielded batch is a sequence of N tensors of x_array followed by M tensors
+    of y_array, where `N == len(x_attrs)` and `M == len(y_attrs)`.
+
+    :param x_array: TileDB array of the features.
+    :param y_array: TileDB array of the labels.
+    :param x_buffer_size: Number of rows to read at a time from x_array.
+    :param y_buffer_size: Number of rows to read at a time from y_array.
+    :param x_attrs: Attribute names of x_array.
+    :param y_attrs: Attribute names of y_array.
+    :param start_offset: Start row offset; defaults to 0.
+    :param stop_offset: Stop row offset; defaults to number of rows.
+    :param dense_tensor_generator_cls: Dense tensor generator type.
+    :param sparse_tensor_generator_cls: Sparse tensor generator type.
+    """
+
+    def get_buffer_size_generator(
+        array: tiledb.Array, attrs: Sequence[str]
+    ) -> Union[TileDBTensorGenerator[DenseTensor], TileDBTensorGenerator[SparseTensor]]:
+        if array.schema.sparse:
+            return sparse_tensor_generator_cls(array, attrs)
+        else:
+            return dense_tensor_generator_cls(array, attrs)
+
+    x_gen = get_buffer_size_generator(x_array, x_attrs)
+    y_gen = get_buffer_size_generator(y_array, y_attrs)
+    if not stop_offset:
+        stop_offset = x_array.shape[0]
+    for batch in iter_batches(x_buffer_size, y_buffer_size, start_offset, stop_offset):
+        if batch.x_read_slice:
+            x_gen.read_buffer(batch.x_read_slice)
+        if batch.y_read_slice:
+            y_gen.read_buffer(batch.y_read_slice)
+        x_tensors = x_gen.iter_tensors(batch.x_buffer_slice)
+        y_tensors = y_gen.iter_tensors(batch.y_buffer_slice)
+        yield (*x_tensors, *y_tensors)
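
For orientation, here is a minimal usage sketch of the relocated tensor_generator. It is not part of this commit; the array URIs, buffer sizes, and attribute names are hypothetical, and the dense/sparse generator classes are left at their defaults:

    import tiledb

    from tiledb.ml.readers._tensor_gen import tensor_generator

    # Open the feature and label arrays and stream batches of tensors from them.
    with tiledb.open("x_array_uri") as x_array, tiledb.open("y_array_uri") as y_array:
        batches = tensor_generator(
            x_array=x_array,
            y_array=y_array,
            x_buffer_size=1024,
            y_buffer_size=1024,
            x_attrs=["features"],  # hypothetical attribute names
            y_attrs=["labels"],
        )
        # Each batch is (*x_tensors, *y_tensors): one tensor per requested attribute.
        for x_batch, y_batch in batches:
            ...  # feed the pair to a training loop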

tiledb/ml/readers/pytorch.py  (+1 -1)

@@ -10,8 +10,8 @@
 
 import tiledb
 
-from ._batch_utils import SparseTileDBTensorGenerator, tensor_generator
 from ._buffer_utils import get_attr_names, get_buffer_size
+from ._tensor_gen import SparseTileDBTensorGenerator, tensor_generator
 
 
 def PyTorchTileDBDataLoader(

tiledb/ml/readers/tensorflow.py  (+1 -1)

@@ -8,8 +8,8 @@
 
 import tiledb
 
-from ._batch_utils import SparseTileDBTensorGenerator, tensor_generator
 from ._buffer_utils import get_attr_names, get_buffer_size
+from ._tensor_gen import SparseTileDBTensorGenerator, tensor_generator
 
 # TODO: We have to track the following issues:
 # - https://github.com/tensorflow/tensorflow/issues/47532
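
Both reader modules import SparseTileDBTensorGenerator so that each can plug in a framework-specific _tensor_from_coo. Purely as an illustration (this is not the actual PyTorch or TensorFlow implementation in this repo), a concrete subclass that materializes scipy COO matrices could look like this:

    from typing import Sequence

    import numpy as np
    import scipy.sparse as sp

    from tiledb.ml.readers._tensor_gen import SparseTileDBTensorGenerator

    class ScipyCooGenerator(SparseTileDBTensorGenerator[sp.coo_matrix]):
        @staticmethod
        def _tensor_from_coo(
            data: np.ndarray,
            coords: np.ndarray,
            dense_shape: Sequence[int],
            dtype: np.dtype,
        ) -> sp.coo_matrix:
            # coords stacks (row, col) pairs along the last axis, as produced by
            # SparseTileDBTensorGenerator.iter_tensors.
            rows, cols = coords[:, 0], coords[:, 1]
            return sp.coo_matrix(
                (data.astype(dtype), (rows, cols)), shape=tuple(dense_shape)
            )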
