@@ -49,7 +49,7 @@ def iter_tensors(self, buffer_slice: slice) -> Iterator[np.ndarray]:
49
49
yield buf_array [buffer_slice ]
50
50
51
51
52
- class SparseTileDBTensorGenerator (TileDBTensorGenerator [Tensor ]):
52
+ class TileDBSparseTensorGenerator (TileDBTensorGenerator [Tensor ]):
53
53
def __init__ (self , array : tiledb .Array , attrs : Sequence [str ]) -> None :
54
54
schema = array .schema
55
55
if schema .ndim != 2 :
@@ -96,8 +96,8 @@ def _tensor_from_coo(
96
96
"""Convert a scipy.sparse.coo_matrix to a Tensor"""
97
97
98
98
99
- DenseTensor = TypeVar ("DenseTensor " )
100
- SparseTensor = TypeVar ("SparseTensor " )
99
+ DT = TypeVar ("DT " )
100
+ ST = TypeVar ("ST " )
101
101
102
102
103
103
def tensor_generator (
@@ -109,13 +109,9 @@ def tensor_generator(
109
109
y_attrs : Sequence [str ],
110
110
start_offset : int = 0 ,
111
111
stop_offset : int = 0 ,
112
- dense_tensor_generator_cls : Type [
113
- TileDBTensorGenerator [DenseTensor ]
114
- ] = TileDBNumpyGenerator ,
115
- sparse_tensor_generator_cls : Type [
116
- TileDBTensorGenerator [SparseTensor ]
117
- ] = SparseTileDBTensorGenerator ,
118
- ) -> Iterator [Sequence [Union [DenseTensor , SparseTensor ]]]:
112
+ dense_generator_cls : Type [TileDBTensorGenerator [DT ]] = TileDBNumpyGenerator ,
113
+ sparse_generator_cls : Type [TileDBTensorGenerator [ST ]] = TileDBSparseTensorGenerator ,
114
+ ) -> Iterator [Sequence [Union [DT , ST ]]]:
119
115
"""
120
116
Generator for batches of tensors.
121
117
@@ -130,20 +126,19 @@ def tensor_generator(
130
126
:param y_attrs: Attribute names of y_array.
131
127
:param start_offset: Start row offset; defaults to 0.
132
128
:param stop_offset: Stop row offset; defaults to number of rows.
133
- :param dense_tensor_generator_cls : Dense tensor generator type.
134
- :param sparse_tensor_generator_cls : Sparse tensor generator type.
129
+ :param dense_generator_cls : Dense tensor generator type.
130
+ :param sparse_generator_cls : Sparse tensor generator type.
135
131
"""
136
-
137
- def get_buffer_size_generator (
138
- array : tiledb .Array , attrs : Sequence [str ]
139
- ) -> Union [TileDBTensorGenerator [DenseTensor ], TileDBTensorGenerator [SparseTensor ]]:
140
- if array .schema .sparse :
141
- return sparse_tensor_generator_cls (array , attrs )
142
- else :
143
- return dense_tensor_generator_cls (array , attrs )
144
-
145
- x_gen = get_buffer_size_generator (x_array , x_attrs )
146
- y_gen = get_buffer_size_generator (y_array , y_attrs )
132
+ x_gen : Union [TileDBTensorGenerator [DT ], TileDBTensorGenerator [ST ]] = (
133
+ sparse_generator_cls (x_array , x_attrs )
134
+ if x_array .schema .sparse
135
+ else dense_generator_cls (x_array , x_attrs )
136
+ )
137
+ y_gen : Union [TileDBTensorGenerator [DT ], TileDBTensorGenerator [ST ]] = (
138
+ sparse_generator_cls (y_array , y_attrs )
139
+ if y_array .schema .sparse
140
+ else dense_generator_cls (y_array , y_attrs )
141
+ )
147
142
if not stop_offset :
148
143
stop_offset = x_array .shape [0 ]
149
144
for batch in iter_batches (x_buffer_size , y_buffer_size , start_offset , stop_offset ):
0 commit comments