Skip to content

Commit 1e4ffe8

Browse files
authored
Read only the necessary dimensions and attributes (#113)
* Replace BaseBatch.set_buffer_offset with read_buffer * Read only the necessary dimensions and attributes * Replace immutable lists with tuples
1 parent 7e155c5 commit 1e4ffe8

File tree

1 file changed

+30
-24
lines changed

1 file changed

+30
-24
lines changed

tiledb/ml/readers/_batch_utils.py

+30-24
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from __future__ import annotations
22

33
from abc import ABC, abstractmethod
4-
from concurrent.futures import ThreadPoolExecutor
5-
from typing import Generic, Iterator, Mapping, Optional, Sequence, Type, TypeVar, Union
4+
from concurrent import futures
5+
from typing import Generic, Iterator, Optional, Sequence, Type, TypeVar, Union
66

77
import numpy as np
88
import scipy.sparse as sp
@@ -25,18 +25,18 @@ def __init__(
2525
self._batch_size = batch_size
2626

2727
@abstractmethod
28-
def set_buffer_offset(self, buffer: Mapping[str, np.ndarray], offset: int) -> None:
29-
"""Set the current buffer from which subsequent batches are to be read.
28+
def read_buffer(self, array: tiledb.Array, buffer_slice: slice) -> None:
29+
"""Read a slice from a TileDB array into a buffer.
3030
31-
:param buffer: Mapping of attribute names to numpy arrays.
32-
:param offset: Start offset of the buffer in the TileDB array.
31+
:param array: TileDB array to read from.
32+
:param buffer_slice: Slice of the array to read.
3333
"""
3434

3535
@abstractmethod
3636
def set_batch_slice(self, batch_slice: slice) -> None:
37-
"""Set the current batch as a slice of the set buffer.
37+
"""Set the current batch as a slice of the read buffer.
3838
39-
Must be called after `set_buffer_offset`.
39+
Must be called after `read_buffer`.
4040
4141
:param batch_slice: Slice of the buffer to be used as the current batch.
4242
"""
@@ -62,12 +62,14 @@ def __len__(self) -> int:
6262

6363

6464
class BaseDenseBatch(BaseBatch[Tensor]):
65-
def set_buffer_offset(self, buffer: Mapping[str, np.ndarray], offset: int) -> None:
66-
self._buffer = buffer
65+
def read_buffer(self, array: tiledb.Array, buffer_slice: slice) -> None:
66+
self._buffer = array.query(dims=(), attrs=self._attrs)[buffer_slice]
6767

6868
def set_batch_slice(self, batch_slice: slice) -> None:
69-
assert hasattr(self, "_buffer"), "set_buffer_offset() not called"
70-
self._attr_batches = [self._buffer[attr][batch_slice] for attr in self._attrs]
69+
assert hasattr(self, "_buffer"), "read_buffer() not called"
70+
self._attr_batches = tuple(
71+
self._buffer[attr][batch_slice] for attr in self._attrs
72+
)
7173

7274
def iter_tensors(self, perm_idxs: Optional[np.ndarray] = None) -> Iterator[Tensor]:
7375
assert hasattr(self, "_attr_batches"), "set_batch_slice() not called"
@@ -103,20 +105,24 @@ def __init__(
103105
self._dense_shape = (batch_size, schema.shape[1])
104106
self._attr_dtypes = tuple(schema.attr(attr).dtype for attr in self._attrs)
105107

106-
def set_buffer_offset(self, buffer: Mapping[str, np.ndarray], offset: int) -> None:
108+
def read_buffer(self, array: tiledb.Array, buffer_slice: slice) -> None:
109+
buffer = array.query(attrs=self._attrs)[buffer_slice]
107110
# COO to CSR transformation for batching and row slicing
108111
row = buffer[self._row_dim]
109112
col = buffer[self._col_dim]
110113
# Normalize indices: We want the coordinate indices to be in the [0, batch_size]
111114
# range. If we do not normalize, the sparse tensor is still created, but with a
112115
# dimension of [0, max(coord_index)], which is overkill
113-
self._buffer_csrs = [
116+
offset = buffer_slice.start
117+
self._buffer_csrs = tuple(
114118
sp.csr_matrix((buffer[attr], (row - offset, col))) for attr in self._attrs
115-
]
119+
)
116120

117121
def set_batch_slice(self, batch_slice: slice) -> None:
118-
assert hasattr(self, "_buffer_csrs"), "set_buffer_offset() not called"
119-
self._batch_csrs = [buffer_csr[batch_slice] for buffer_csr in self._buffer_csrs]
122+
assert hasattr(self, "_buffer_csrs"), "read_buffer() not called"
123+
self._batch_csrs = tuple(
124+
buffer_csr[batch_slice] for buffer_csr in self._buffer_csrs
125+
)
120126

121127
def iter_tensors(self, perm_idxs: Optional[np.ndarray] = None) -> Iterator[Tensor]:
122128
assert hasattr(self, "_batch_csrs"), "set_batch_slice() not called"
@@ -208,15 +214,15 @@ def batch_factory(
208214

209215
x_batch = batch_factory(x_array.schema, x_attrs)
210216
y_batch = batch_factory(y_array.schema, y_attrs)
211-
with ThreadPoolExecutor(max_workers=2) as executor:
217+
with futures.ThreadPoolExecutor(max_workers=2) as executor:
212218
for offset in range(start_offset, stop_offset, buffer_size):
213-
x_buffer, y_buffer = executor.map(
214-
lambda array: array[offset : offset + buffer_size], # type: ignore
215-
(x_array, y_array),
219+
buffer_slice = slice(offset, offset + buffer_size)
220+
futures.wait(
221+
(
222+
executor.submit(x_batch.read_buffer, x_array, buffer_slice),
223+
executor.submit(y_batch.read_buffer, y_array, buffer_slice),
224+
)
216225
)
217-
x_batch.set_buffer_offset(x_buffer, offset)
218-
y_batch.set_buffer_offset(y_buffer, offset)
219-
220226
# Split the buffer_size into batch_size chunks
221227
batch_offsets = np.arange(
222228
0, min(buffer_size, stop_offset - offset), batch_size

0 commit comments

Comments
 (0)