@@ -50,7 +50,7 @@ def is_numpy_array(x):
50
50
is_torch_array
51
51
is_dask_array
52
52
is_jax_array
53
- is_pydata_sparse
53
+ is_pydata_sparse_array
54
54
"""
55
55
# Avoid importing NumPy if it isn't already
56
56
if 'numpy' not in sys .modules :
@@ -80,7 +80,7 @@ def is_cupy_array(x):
80
80
is_torch_array
81
81
is_dask_array
82
82
is_jax_array
83
- is_pydata_sparse
83
+ is_pydata_sparse_array
84
84
"""
85
85
# Avoid importing NumPy if it isn't already
86
86
if 'cupy' not in sys .modules :
@@ -107,7 +107,7 @@ def is_torch_array(x):
107
107
is_cupy_array
108
108
is_dask_array
109
109
is_jax_array
110
- is_pydata_sparse
110
+ is_pydata_sparse_array
111
111
"""
112
112
# Avoid importing torch if it isn't already
113
113
if 'torch' not in sys .modules :
@@ -134,7 +134,7 @@ def is_dask_array(x):
134
134
is_cupy_array
135
135
is_torch_array
136
136
is_jax_array
137
- is_pydata_sparse
137
+ is_pydata_sparse_array
138
138
"""
139
139
# Avoid importing dask if it isn't already
140
140
if 'dask.array' not in sys .modules :
@@ -161,7 +161,7 @@ def is_jax_array(x):
161
161
is_cupy_array
162
162
is_torch_array
163
163
is_dask_array
164
- is_pydata_sparse
164
+ is_pydata_sparse_array
165
165
"""
166
166
# Avoid importing jax if it isn't already
167
167
if 'jax' not in sys .modules :
@@ -172,7 +172,7 @@ def is_jax_array(x):
172
172
return isinstance (x , jax .Array ) or _is_jax_zero_gradient_array (x )
173
173
174
174
175
- def is_pydata_sparse (x ) -> bool :
175
+ def is_pydata_sparse_array (x ) -> bool :
176
176
"""
177
177
Return True if `x` is an array from the `sparse` package.
178
178
@@ -219,7 +219,7 @@ def is_array_api_obj(x):
219
219
or is_torch_array (x ) \
220
220
or is_dask_array (x ) \
221
221
or is_jax_array (x ) \
222
- or is_pydata_sparse (x ) \
222
+ or is_pydata_sparse_array (x ) \
223
223
or hasattr (x , '__array_namespace__' )
224
224
225
225
def _check_api_version (api_version ):
@@ -288,7 +288,7 @@ def your_function(x, y):
288
288
is_torch_array
289
289
is_dask_array
290
290
is_jax_array
291
- is_pydata_sparse
291
+ is_pydata_sparse_array
292
292
293
293
"""
294
294
if use_compat not in [None , True , False ]:
@@ -348,7 +348,7 @@ def your_function(x, y):
348
348
# not have a wrapper submodule for it.
349
349
import jax .experimental .array_api as jnp
350
350
namespaces .add (jnp )
351
- elif is_pydata_sparse (x ):
351
+ elif is_pydata_sparse_array (x ):
352
352
if use_compat is True :
353
353
_check_api_version (api_version )
354
354
raise ValueError ("`sparse` does not have an array-api-compat wrapper" )
@@ -451,7 +451,7 @@ def device(x: Array, /) -> Device:
451
451
return x .device ()
452
452
else :
453
453
return x .device
454
- elif is_pydata_sparse (x ):
454
+ elif is_pydata_sparse_array (x ):
455
455
# `sparse` will gain `.device`, so check for this first.
456
456
x_device = getattr (x , 'device' , None )
457
457
if x_device is not None :
@@ -583,7 +583,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
583
583
# This import adds to_device to x
584
584
import jax .experimental .array_api # noqa: F401
585
585
return x .to_device (device , stream = stream )
586
- elif is_pydata_sparse (x ) and device == _device (x ):
586
+ elif is_pydata_sparse_array (x ) and device == _device (x ):
587
587
# Perform trivial check to return the same array if
588
588
# device is same instead of err-ing.
589
589
return x
@@ -613,7 +613,7 @@ def size(x):
613
613
"is_jax_array" ,
614
614
"is_numpy_array" ,
615
615
"is_torch_array" ,
616
- "is_pydata_sparse " ,
616
+ "is_pydata_sparse_array " ,
617
617
"size" ,
618
618
"to_device" ,
619
619
]
0 commit comments