Skip to content

Commit 7ebc3c0

Browse files
committed
Rename is_sparse_array -> is_pydata_sparse.
1 parent b92a35c commit 7ebc3c0

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

array_api_compat/common/_helpers.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def is_numpy_array(x):
5050
is_torch_array
5151
is_dask_array
5252
is_jax_array
53-
is_sparse_array
53+
is_pydata_sparse
5454
"""
5555
# Avoid importing NumPy if it isn't already
5656
if 'numpy' not in sys.modules:
@@ -80,7 +80,7 @@ def is_cupy_array(x):
8080
is_torch_array
8181
is_dask_array
8282
is_jax_array
83-
is_sparse_array
83+
is_pydata_sparse
8484
"""
8585
# Avoid importing NumPy if it isn't already
8686
if 'cupy' not in sys.modules:
@@ -107,7 +107,7 @@ def is_torch_array(x):
107107
is_cupy_array
108108
is_dask_array
109109
is_jax_array
110-
is_sparse_array
110+
is_pydata_sparse
111111
"""
112112
# Avoid importing torch if it isn't already
113113
if 'torch' not in sys.modules:
@@ -134,7 +134,7 @@ def is_dask_array(x):
134134
is_cupy_array
135135
is_torch_array
136136
is_jax_array
137-
is_sparse_array
137+
is_pydata_sparse
138138
"""
139139
# Avoid importing dask if it isn't already
140140
if 'dask.array' not in sys.modules:
@@ -161,7 +161,7 @@ def is_jax_array(x):
161161
is_cupy_array
162162
is_torch_array
163163
is_dask_array
164-
is_sparse_array
164+
is_pydata_sparse
165165
"""
166166
# Avoid importing jax if it isn't already
167167
if 'jax' not in sys.modules:
@@ -172,7 +172,7 @@ def is_jax_array(x):
172172
return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)
173173

174174

175-
def is_sparse_array(x) -> bool:
175+
def is_pydata_sparse(x) -> bool:
176176
"""
177177
Return True if `x` is an array from the `sparse` package.
178178
@@ -219,7 +219,7 @@ def is_array_api_obj(x):
219219
or is_torch_array(x) \
220220
or is_dask_array(x) \
221221
or is_jax_array(x) \
222-
or is_sparse_array(x) \
222+
or is_pydata_sparse(x) \
223223
or hasattr(x, '__array_namespace__')
224224

225225
def _check_api_version(api_version):
@@ -288,7 +288,7 @@ def your_function(x, y):
288288
is_torch_array
289289
is_dask_array
290290
is_jax_array
291-
is_sparse_array
291+
is_pydata_sparse
292292
293293
"""
294294
if use_compat not in [None, True, False]:
@@ -348,7 +348,7 @@ def your_function(x, y):
348348
# not have a wrapper submodule for it.
349349
import jax.experimental.array_api as jnp
350350
namespaces.add(jnp)
351-
elif is_sparse_array(x):
351+
elif is_pydata_sparse(x):
352352
if use_compat is True:
353353
_check_api_version(api_version)
354354
raise ValueError("`sparse` does not have an array-api-compat wrapper")
@@ -451,7 +451,7 @@ def device(x: Array, /) -> Device:
451451
return x.device()
452452
else:
453453
return x.device
454-
elif is_sparse_array(x):
454+
elif is_pydata_sparse(x):
455455
# `sparse` will gain `.device`, so check for this first.
456456
x_device = getattr(x, 'device', None)
457457
if x_device is not None:
@@ -583,7 +583,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
583583
# This import adds to_device to x
584584
import jax.experimental.array_api # noqa: F401
585585
return x.to_device(device, stream=stream)
586-
elif is_sparse_array(x) and device == _device(x):
586+
elif is_pydata_sparse(x) and device == _device(x):
587587
# Perform trivial check to return the same array if
588588
# device is same instead of err-ing.
589589
return x
@@ -613,7 +613,7 @@ def size(x):
613613
"is_jax_array",
614614
"is_numpy_array",
615615
"is_torch_array",
616-
"is_sparse_array",
616+
"is_pydata_sparse",
617617
"size",
618618
"to_device",
619619
]

0 commit comments

Comments
 (0)