Skip to content

Commit 7e155c5

Browse files
committed
BaseSparseBatch fix: take into account all requested attributes, not just the first one
1 parent 2eaf47b commit 7e155c5

File tree

1 file changed

+20
-12
lines changed

1 file changed

+20
-12
lines changed

Diff for: tiledb/ml/readers/_batch_utils.py

+20 -12
Original file line number | Diff line number | Diff line change
@@ -110,33 +110,41 @@ def set_buffer_offset(self, buffer: Mapping[str, np.ndarray], offset: int) -> No
110110
# Normalize indices: We want the coords indices to be in the [0, batch_size]
111111
# range. If we do not normalize the sparse tensor is being created but with a
112112
# dimension [0, max(coord_index)], which is overkill
113-
self._buffer_csr = sp.csr_matrix((buffer[self._attrs[0]], (row - offset, col)))
113+
self._buffer_csrs = [
114+
sp.csr_matrix((buffer[attr], (row - offset, col))) for attr in self._attrs
115+
]
114116

115117
def set_batch_slice(self, batch_slice: slice) -> None:
116-
assert hasattr(self, "_buffer_csr"), "set_buffer_offset() not called"
117-
self._batch_csr = self._buffer_csr[batch_slice]
118+
assert hasattr(self, "_buffer_csrs"), "set_buffer_offset() not called"
119+
self._batch_csrs = [buffer_csr[batch_slice] for buffer_csr in self._buffer_csrs]
118120

119121
def iter_tensors(self, perm_idxs: Optional[np.ndarray] = None) -> Iterator[Tensor]:
120-
assert hasattr(self, "_batch_csr"), "set_batch_slice() not called"
122+
assert hasattr(self, "_batch_csrs"), "set_batch_slice() not called"
121123
if perm_idxs is not None:
122124
raise NotImplementedError(
123125
"within_batch_shuffle not implemented for sparse arrays"
124126
)
125-
batch_coo = self._batch_csr.tocoo()
126-
data = batch_coo.data
127-
coords = np.stack((batch_coo.row, batch_coo.col), axis=-1)
128-
for dtype in self._attr_dtypes:
127+
for batch_csr, dtype in zip(self._batch_csrs, self._attr_dtypes):
128+
batch_coo = batch_csr.tocoo()
129+
data = batch_coo.data
130+
coords = np.stack((batch_coo.row, batch_coo.col), axis=-1)
129131
yield self._tensor_from_coo(data, coords, self._dense_shape, dtype)
130132

131133
def __len__(self) -> int:
132-
assert hasattr(self, "_batch_csr"), "set_batch_slice() not called"
134+
assert hasattr(self, "_batch_csrs"), "set_batch_slice() not called"
133135
# return number of non-zero rows
134-
return int((self._batch_csr.getnnz(axis=1) > 0).sum())
136+
lengths = {
137+
int((batch_csr.getnnz(axis=1) > 0).sum()) for batch_csr in self._batch_csrs
138+
}
139+
assert len(lengths) == 1, f"Multiple different batch lengths: {lengths}"
140+
return lengths.pop()
135141

136142
def __bool__(self) -> bool:
137-
assert hasattr(self, "_batch_csr"), "set_batch_slice() not called"
143+
assert hasattr(self, "_batch_csrs"), "set_batch_slice() not called"
138144
# faster version of __len__() > 0
139-
return len(self._batch_csr.data) > 0
145+
lengths = {len(batch_csr.data) for batch_csr in self._batch_csrs}
146+
assert len(lengths) == 1, f"Multiple different batch lengths: {lengths}"
147+
return lengths.pop() > 0
140148

141149
@staticmethod
142150
@abstractmethod

0 commit comments

Comments
 (0)