Skip to content

Commit 7b90a4e

Browse files
committed
refactor tests for asarrays
1 parent 273ff4e commit 7b90a4e

File tree

1 file changed

+89
-92
lines changed

1 file changed

+89
-92
lines changed

tests/test_helpers.py

+89-92
Original file line numberDiff line numberDiff line change
@@ -59,95 +59,92 @@ def test_xp(self, xp: ModuleType):
5959
xp_assert_equal(actual, expected)
6060

6161

62-
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
63-
@pytest.mark.parametrize(
64-
("dtype", "b", "defined"),
65-
[
66-
# Well-defined cases of dtype promotion from Python scalar to Array
67-
# bool vs. bool
68-
("bool", True, True),
69-
# int vs. xp.*int*, xp.float*, xp.complex*
70-
("int16", 1, True),
71-
("uint8", 1, True),
72-
("float32", 1, True),
73-
("float64", 1, True),
74-
("complex64", 1, True),
75-
("complex128", 1, True),
76-
# float vs. xp.float, xp.complex
77-
("float32", 1.0, True),
78-
("float64", 1.0, True),
79-
("complex64", 1.0, True),
80-
("complex128", 1.0, True),
81-
# complex vs. xp.complex
82-
("complex64", 1.0j, True),
83-
("complex128", 1.0j, True),
84-
# Undefined cases
85-
("bool", 1, False),
86-
("int64", 1.0, False),
87-
("float64", 1.0j, False),
88-
],
89-
)
90-
def test_asarrays_array_vs_scalar(
91-
dtype: str, b: int | float | complex, defined: bool, xp: ModuleType
92-
):
93-
a = xp.asarray(1, dtype=getattr(xp, dtype))
94-
95-
xa, xb = asarrays(a, b, xp)
96-
assert xa.dtype == a.dtype
97-
if defined:
98-
assert xb.dtype == a.dtype
99-
else:
100-
assert xb.dtype == xp.asarray(b).dtype
101-
102-
xbr, xar = asarrays(b, a, xp)
103-
assert xar.dtype == xa.dtype
104-
assert xbr.dtype == xb.dtype
105-
106-
107-
def test_asarrays_scalar_vs_scalar(xp: ModuleType):
108-
a, b = asarrays(1, 2.2, xp=xp)
109-
assert a.dtype == xp.asarray(1).dtype # Default dtype
110-
assert b.dtype == xp.asarray(2.2).dtype # Default dtype; not broadcasted
111-
112-
113-
ALL_TYPES = (
114-
"int8",
115-
"int16",
116-
"int32",
117-
"int64",
118-
"uint8",
119-
"uint16",
120-
"uint32",
121-
"uint64",
122-
"float32",
123-
"float64",
124-
"complex64",
125-
"complex128",
126-
"bool",
127-
)
128-
129-
130-
@pytest.mark.parametrize("a_type", ALL_TYPES)
131-
@pytest.mark.parametrize("b_type", ALL_TYPES)
132-
def test_asarrays_array_vs_array(a_type: str, b_type: str, xp: ModuleType):
133-
"""
134-
Test that when both inputs of asarray are already Array API objects,
135-
they are returned unchanged.
136-
"""
137-
a = xp.asarray(1, dtype=getattr(xp, a_type))
138-
b = xp.asarray(1, dtype=getattr(xp, b_type))
139-
xa, xb = asarrays(a, b, xp)
140-
assert xa.dtype == a.dtype
141-
assert xb.dtype == b.dtype
142-
143-
144-
@pytest.mark.parametrize("dtype", [np.float64, np.complex128])
145-
def test_asarrays_numpy_generics(dtype: type):
146-
"""
147-
Test special case of np.float64 and np.complex128,
148-
which are subclasses of float and complex.
149-
"""
150-
a = dtype(0)
151-
xa, xb = asarrays(a, 0, xp=np)
152-
assert xa.dtype == dtype
153-
assert xb.dtype == dtype
62+
class TestAsArrays:
63+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
64+
@pytest.mark.parametrize(
65+
("dtype", "b", "defined"),
66+
[
67+
# Well-defined cases of dtype promotion from Python scalar to Array
68+
# bool vs. bool
69+
("bool", True, True),
70+
# int vs. xp.*int*, xp.float*, xp.complex*
71+
("int16", 1, True),
72+
("uint8", 1, True),
73+
("float32", 1, True),
74+
("float64", 1, True),
75+
("complex64", 1, True),
76+
("complex128", 1, True),
77+
# float vs. xp.float, xp.complex
78+
("float32", 1.0, True),
79+
("float64", 1.0, True),
80+
("complex64", 1.0, True),
81+
("complex128", 1.0, True),
82+
# complex vs. xp.complex
83+
("complex64", 1.0j, True),
84+
("complex128", 1.0j, True),
85+
# Undefined cases
86+
("bool", 1, False),
87+
("int64", 1.0, False),
88+
("float64", 1.0j, False),
89+
],
90+
)
91+
def test_array_vs_scalar(
92+
self, dtype: str, b: int | float | complex, defined: bool, xp: ModuleType
93+
):
94+
a = xp.asarray(1, dtype=getattr(xp, dtype))
95+
96+
xa, xb = asarrays(a, b, xp)
97+
assert xa.dtype == a.dtype
98+
if defined:
99+
assert xb.dtype == a.dtype
100+
else:
101+
assert xb.dtype == xp.asarray(b).dtype
102+
103+
xbr, xar = asarrays(b, a, xp)
104+
assert xar.dtype == xa.dtype
105+
assert xbr.dtype == xb.dtype
106+
107+
def test_scalar_vs_scalar(self, xp: ModuleType):
108+
a, b = asarrays(1, 2.2, xp=xp)
109+
assert a.dtype == xp.asarray(1).dtype # Default dtype
110+
assert b.dtype == xp.asarray(2.2).dtype # Default dtype; not broadcasted
111+
112+
ALL_TYPES: tuple[str, ...] = (
113+
"int8",
114+
"int16",
115+
"int32",
116+
"int64",
117+
"uint8",
118+
"uint16",
119+
"uint32",
120+
"uint64",
121+
"float32",
122+
"float64",
123+
"complex64",
124+
"complex128",
125+
"bool",
126+
)
127+
128+
@pytest.mark.parametrize("a_type", ALL_TYPES)
129+
@pytest.mark.parametrize("b_type", ALL_TYPES)
130+
def test_array_vs_array(self, a_type: str, b_type: str, xp: ModuleType):
131+
"""
132+
Test that when both inputs of asarray are already Array API objects,
133+
they are returned unchanged.
134+
"""
135+
a = xp.asarray(1, dtype=getattr(xp, a_type))
136+
b = xp.asarray(1, dtype=getattr(xp, b_type))
137+
xa, xb = asarrays(a, b, xp)
138+
assert xa.dtype == a.dtype
139+
assert xb.dtype == b.dtype
140+
141+
@pytest.mark.parametrize("dtype", [np.float64, np.complex128])
142+
def test_numpy_generics(self, dtype: type):
143+
"""
144+
Test special case of np.float64 and np.complex128,
145+
which are subclasses of float and complex.
146+
"""
147+
a = dtype(0)
148+
xa, xb = asarrays(a, 0, xp=np)
149+
assert xa.dtype == dtype
150+
assert xb.dtype == dtype

0 commit comments

Comments
 (0)