@@ -22,11 +22,14 @@ class ArraySpec:
    shape: Sequence[int]
    key_dim: int
    key_dim_dtype: np.dtype
+    non_key_dim_dtype: np.dtype
    num_fields: int

    def tensor_kind(self, supports_csr: bool) -> TensorKind:
        if not self.sparse:
            return TensorKind.DENSE
+        elif not np.issubdtype(self.non_key_dim_dtype, np.integer):
+            return TensorKind.RAGGED
        elif len(self.shape) == 2 and supports_csr:
            return TensorKind.SPARSE_CSR
        else:
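For context, a minimal self-contained sketch (not the patched module itself) of how the new dtype check routes a spec to a tensor kind: a sparse array whose non-key dimension dtype is not an integer type cannot be mapped to a fixed dense row shape, so it is treated as ragged. The `TensorKind` enum below is a stand-in limited to the members used here.

```python
import enum

import numpy as np


class TensorKind(enum.Enum):
    # stand-in for the library's TensorKind enum, only the members used here
    DENSE = enum.auto()
    RAGGED = enum.auto()
    SPARSE_COO = enum.auto()
    SPARSE_CSR = enum.auto()


def tensor_kind(sparse, shape, non_key_dim_dtype, supports_csr):
    # mirrors ArraySpec.tensor_kind above, with the spec fields passed explicitly
    if not sparse:
        return TensorKind.DENSE
    elif not np.issubdtype(non_key_dim_dtype, np.integer):
        # non-integer non-key dimensions have no fixed positional layout -> ragged rows
        return TensorKind.RAGGED
    elif len(shape) == 2 and supports_csr:
        return TensorKind.SPARSE_CSR
    else:
        return TensorKind.SPARSE_COO


assert tensor_kind(True, (8, 4), np.dtype(np.float32), True) is TensorKind.RAGGED
assert tensor_kind(True, (8, 4), np.dtype(np.int32), True) is TensorKind.SPARSE_CSR
assert tensor_kind(False, (8, 4), np.dtype(np.float32), True) is TensorKind.DENSE
```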
@@ -42,6 +45,7 @@ def parametrize_for_dataset(
    x_key_dim=(0, 1),
    y_key_dim=(0, 1),
    key_dim_dtype=(np.dtype(np.int32), np.dtype("datetime64[D]"), np.dtype(np.bytes_)),
+    non_key_dim_dtype=(np.dtype(np.int32), np.dtype(np.float32)),
    num_fields=(0, 1, 2),
    batch_size=(8,),
    shuffle_buffer_size=(16,),
@@ -57,6 +61,7 @@ def parametrize_for_dataset(
        x_key_dim_,
        y_key_dim_,
        key_dim_dtype_,
+        non_key_dim_dtype_,
        num_fields_,
        batch_size_,
        shuffle_buffer_size_,
@@ -69,6 +74,7 @@ def parametrize_for_dataset(
        x_key_dim,
        y_key_dim,
        key_dim_dtype,
+        non_key_dim_dtype,
        num_fields,
        batch_size,
        shuffle_buffer_size,
@@ -78,9 +84,12 @@ def parametrize_for_dataset(
        if not x_sparse_ or not y_sparse_:
            if not np.issubdtype(key_dim_dtype_, np.integer):
                continue
+            if not np.issubdtype(non_key_dim_dtype_, np.integer):
+                continue

-        x_spec = ArraySpec(x_sparse_, x_shape_, x_key_dim_, key_dim_dtype_, num_fields_)
-        y_spec = ArraySpec(y_sparse_, y_shape_, y_key_dim_, key_dim_dtype_, num_fields_)
+        common_args = (key_dim_dtype_, non_key_dim_dtype_, num_fields_)
+        x_spec = ArraySpec(x_sparse_, x_shape_, x_key_dim_, *common_args)
+        y_spec = ArraySpec(y_sparse_, y_shape_, y_key_dim_, *common_args)
        argvalues.append(
            (x_spec, y_spec, batch_size_, shuffle_buffer_size_, num_workers_)
        )
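The skip logic above prunes parameter combinations up front: when either X or Y is dense, both the key and non-key dimension dtypes must be integral. A hedged reproduction of just that predicate, with illustrative names:

```python
import numpy as np


def is_valid_combination(x_sparse, y_sparse, key_dim_dtype, non_key_dim_dtype):
    # dense arrays require integer dimensions, so non-integer dtypes are only
    # exercised when both X and Y are sparse
    if not x_sparse or not y_sparse:
        if not np.issubdtype(key_dim_dtype, np.integer):
            return False
        if not np.issubdtype(non_key_dim_dtype, np.integer):
            return False
    return True


assert is_valid_combination(True, True, np.dtype("datetime64[D]"), np.dtype(np.float32))
assert not is_valid_combination(True, False, np.dtype(np.int32), np.dtype(np.float32))
```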
@@ -101,7 +110,7 @@ def ingest_in_tiledb(tmpdir, spec: ArraySpec):
    transforms = []
    for i in range(data.ndim):
        n = data.shape[i]
-        dtype = spec.key_dim_dtype if i == spec.key_dim else np.dtype("int32")
+        dtype = spec.key_dim_dtype if i == spec.key_dim else spec.non_key_dim_dtype
        if np.issubdtype(dtype, np.number):
            # set the domain to (-n/2, n/2) to test negative domain indexing
            min_value = -(n // 2)
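As a rough illustration of where that `min_value` ends up (a sketch only: the dimension name, tile size, and the `max_value` continuation are assumptions, since the rest of the helper is not shown in this hunk), the shifted bounds become the domain of a `tiledb.Dim`:

```python
import numpy as np
import tiledb

n = 10
min_value = -(n // 2)           # -5, so the domain starts below zero
max_value = min_value + n - 1   # assumed continuation: 4, giving n points in total
dim = tiledb.Dim(name="dim_0", domain=(min_value, max_value), tile=n, dtype=np.dtype(np.int32))
```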
@@ -216,16 +225,28 @@ def validate_tensor_generator(generator, x_spec, y_spec, batch_size, supports_csr
def _validate_tensor(tensor, spec, batch_size, supports_csr):
    tensor_kind = _get_tensor_kind(tensor)
    assert tensor_kind is spec.tensor_kind(supports_csr)
-    num_rows, *row_shape = tensor.shape
+
+    spec_row_shape = spec.shape[1:]
+    if tensor_kind is not TensorKind.RAGGED:
+        num_rows, *row_shape = tensor.shape
+        assert tuple(row_shape) == spec_row_shape
+    else:
+        # every ragged array row has at most `np.prod(spec_row_shape)` elements,
+        # the product of all non-key dimension sizes
+        row_lengths = tuple(map(len, tensor))
+        assert all(row_length <= np.prod(spec_row_shape) for row_length in row_lengths)
+        num_rows = len(row_lengths)
+
    # num_rows may be less than batch_size
    assert num_rows <= batch_size, (num_rows, batch_size)
-    assert tuple(row_shape) == spec.shape[1:]


def _get_tensor_kind(tensor) -> TensorKind:
    if isinstance(tensor, tf.Tensor):
        return TensorKind.DENSE
    if isinstance(tensor, torch.Tensor):
+        if getattr(tensor, "is_nested", False):
+            return TensorKind.RAGGED
        return _torch_tensor_layout_to_kind[tensor.layout]
    return _tensor_type_to_kind[type(tensor)]

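To make the ragged branch concrete, here is a hedged example (assuming TensorFlow and PyTorch are installed) of the two ragged representations involved and how the checks above apply to them: a `tf.RaggedTensor` exposes its per-row lengths, while a PyTorch nested tensor advertises itself through its `is_nested` flag, which `_get_tensor_kind` now checks before the layout-based lookup.

```python
import numpy as np
import tensorflow as tf
import torch

spec_row_shape = (3,)  # non-key dimensions of a hypothetical (batch, 3) spec

# TensorFlow: ragged rows with varying lengths, each at most prod(spec_row_shape)
ragged = tf.ragged.constant([[1.0, 2.0], [3.0], [4.0, 5.0, 6.0]])
row_lengths = ragged.row_lengths().numpy()
assert all(length <= np.prod(spec_row_shape) for length in row_lengths)
assert int(ragged.nrows()) <= 8  # num_rows may be less than batch_size

# PyTorch: a nested tensor reports is_nested=True
nested = torch.nested.nested_tensor([torch.tensor([1.0, 2.0]), torch.tensor([3.0])])
assert nested.is_nested
```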
@@ -236,6 +257,7 @@ def _get_tensor_kind(tensor) -> TensorKind:
    scipy.sparse.coo_matrix: TensorKind.SPARSE_COO,
    scipy.sparse.csr_matrix: TensorKind.SPARSE_CSR,
    tf.SparseTensor: TensorKind.SPARSE_COO,
+    tf.RaggedTensor: TensorKind.RAGGED,
}

_torch_tensor_layout_to_kind = {