Skip to content

Commit d26144d

Browse files
authored
cast numpy scalars to arrays in as_compatible_data (#9403)
* also call `np.asarray` on numpy scalars * check that numpy scalars are properly casted to arrays * don't allow `numpy.ndarray` subclasses * comment on the purpose of the explicit isinstance and `np.asarray`
1 parent 01206da commit d26144d

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

xarray/core/variable.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -320,12 +320,14 @@ def convert_non_numpy_type(data):
320320
else:
321321
data = np.asarray(data)
322322

323-
if not isinstance(data, np.ndarray) and (
323+
# immediately return array-like types except `numpy.ndarray` subclasses and `numpy` scalars
324+
if not isinstance(data, np.ndarray | np.generic) and (
324325
hasattr(data, "__array_function__") or hasattr(data, "__array_namespace__")
325326
):
326327
return cast("T_DuckArray", data)
327328

328-
# validate whether the data is valid data types.
329+
# validate whether the data is valid data types. Also, explicitly cast `numpy`
330+
# subclasses and `numpy` scalars to `numpy.ndarray`
329331
data = np.asarray(data)
330332

331333
if data.dtype.kind in "OMm":

xarray/tests/test_variable.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -2585,7 +2585,12 @@ def test_unchanged_types(self):
25852585
assert source_ndarray(x) is source_ndarray(as_compatible_data(x))
25862586

25872587
def test_converted_types(self):
2588-
for input_array in [[[0, 1, 2]], pd.DataFrame([[0, 1, 2]])]:
2588+
for input_array in [
2589+
[[0, 1, 2]],
2590+
pd.DataFrame([[0, 1, 2]]),
2591+
np.float64(1.4),
2592+
np.str_("abc"),
2593+
]:
25892594
actual = as_compatible_data(input_array)
25902595
assert_array_equal(np.asarray(input_array), actual)
25912596
assert np.ndarray is type(actual)

0 commit comments

Comments
 (0)