|
11 | 11 | cov,
|
12 | 12 | create_diagonal,
|
13 | 13 | expand_dims,
|
| 14 | + isclose, |
14 | 15 | kron,
|
15 | 16 | nunique,
|
16 | 17 | pad,
|
|
23 | 24 | from array_api_extra._lib._utils._typing import Array, Device
|
24 | 25 |
|
25 | 26 | # some xp backends are untyped
|
26 |
| -# mypy: disable-error-code=no-untyped-usage |
| 27 | +# mypy: disable-error-code=no-untyped-def |
27 | 28 |
|
28 | 29 |
|
29 | 30 | @pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no expand_dims")
|
@@ -252,6 +253,139 @@ def test_xp(self, xp: ModuleType):
|
252 | 253 | assert y.shape == (1, 1, 1, 3)
|
253 | 254 |
|
254 | 255 |
|
| 256 | +@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no isdtype") |
| 257 | +class TestIsClose: |
| 258 | + # FIXME use lazywhere to avoid warnings on inf |
| 259 | + @pytest.mark.filterwarnings("ignore:invalid value encountered") |
| 260 | + @pytest.mark.parametrize( |
| 261 | + ("a", "b"), |
| 262 | + [ |
| 263 | + (0.0, 0.0), |
| 264 | + (1.0, 1.0), |
| 265 | + (1.0, 2.0), |
| 266 | + (1.0, -1.0), |
| 267 | + (100.0, 101.0), |
| 268 | + (0, 0), |
| 269 | + (1, 1), |
| 270 | + (1, 2), |
| 271 | + (1, -1), |
| 272 | + (1.0 + 1j, 1.0 + 1j), |
| 273 | + (1.0 + 1j, 1.0 - 1j), |
| 274 | + (float("inf"), float("inf")), |
| 275 | + (float("inf"), 100.0), |
| 276 | + (float("inf"), float("-inf")), |
| 277 | + (float("nan"), float("nan")), |
| 278 | + (float("nan"), 0.0), |
| 279 | + (0.0, float("nan")), |
| 280 | + (1e6, 1e6 + 1), # True - within rtol |
| 281 | + (1e6, 1e6 + 100), # False - outside rtol |
| 282 | + (1e-6, 1.1e-6), # False - outside atol |
| 283 | + (1e-7, 1.1e-7), # True - outside atol |
| 284 | + (1e6 + 0j, 1e6 + 1j), # True - within rtol |
| 285 | + (1e6 + 0j, 1e6 + 100j), # False - outside rtol |
| 286 | + ], |
| 287 | + ) |
| 288 | + def test_basic(self, a: float, b: float, xp: ModuleType): |
| 289 | + a_xp = xp.asarray(a) |
| 290 | + b_xp = xp.asarray(b) |
| 291 | + |
| 292 | + xp_assert_equal(isclose(a_xp, b_xp), xp.asarray(np.isclose(a, b))) |
| 293 | + |
| 294 | + with warnings.catch_warnings(): |
| 295 | + warnings.simplefilter("ignore") |
| 296 | + r_xp = xp.asarray(np.arange(10), dtype=a_xp.dtype) |
| 297 | + ar_xp = a_xp * r_xp |
| 298 | + br_xp = b_xp * r_xp |
| 299 | + ar_np = a * np.arange(10) |
| 300 | + br_np = b * np.arange(10) |
| 301 | + |
| 302 | + xp_assert_equal(isclose(ar_xp, br_xp), xp.asarray(np.isclose(ar_np, br_np))) |
| 303 | + |
| 304 | + @pytest.mark.parametrize("dtype", ["float32", "int32"]) |
| 305 | + def test_broadcast(self, dtype: str, xp: ModuleType): |
| 306 | + dtype = getattr(xp, dtype) |
| 307 | + a = xp.asarray([1, 2, 3], dtype=dtype) |
| 308 | + b = xp.asarray([[1], [5]], dtype=dtype) |
| 309 | + actual = isclose(a, b) |
| 310 | + expect = xp.asarray( |
| 311 | + [[True, False, False], [False, False, False]], dtype=xp.bool |
| 312 | + ) |
| 313 | + |
| 314 | + xp_assert_equal(actual, expect) |
| 315 | + |
| 316 | + # FIXME use lazywhere to avoid warnings on inf |
| 317 | + @pytest.mark.filterwarnings("ignore:invalid value encountered") |
| 318 | + def test_some_inf(self, xp: ModuleType): |
| 319 | + a = xp.asarray([0.0, 1.0, float("inf"), float("inf"), float("inf")]) |
| 320 | + b = xp.asarray([1e-9, 1.0, float("inf"), float("-inf"), 2.0]) |
| 321 | + actual = isclose(a, b) |
| 322 | + xp_assert_equal(actual, xp.asarray([True, True, True, False, False])) |
| 323 | + |
| 324 | + def test_equal_nan(self, xp: ModuleType): |
| 325 | + a = xp.asarray([float("nan"), float("nan"), 1.0]) |
| 326 | + b = xp.asarray([float("nan"), 1.0, float("nan")]) |
| 327 | + xp_assert_equal(isclose(a, b), xp.asarray([False, False, False])) |
| 328 | + xp_assert_equal(isclose(a, b, equal_nan=True), xp.asarray([True, False, False])) |
| 329 | + |
| 330 | + @pytest.mark.parametrize("dtype", ["float32", "complex64", "int32"]) |
| 331 | + def test_tolerance(self, dtype: str, xp: ModuleType): |
| 332 | + dtype = getattr(xp, dtype) |
| 333 | + a = xp.asarray([100, 100], dtype=dtype) |
| 334 | + b = xp.asarray([101, 102], dtype=dtype) |
| 335 | + xp_assert_equal(isclose(a, b), xp.asarray([False, False])) |
| 336 | + xp_assert_equal(isclose(a, b, atol=1), xp.asarray([True, False])) |
| 337 | + xp_assert_equal(isclose(a, b, rtol=0.01), xp.asarray([True, False])) |
| 338 | + |
| 339 | + # Attempt to trigger division by 0 in rtol on int dtype |
| 340 | + xp_assert_equal(isclose(a, b, rtol=0), xp.asarray([False, False])) |
| 341 | + xp_assert_equal(isclose(a, b, atol=1, rtol=0), xp.asarray([True, False])) |
| 342 | + |
| 343 | + def test_very_small_numbers(self, xp: ModuleType): |
| 344 | + a = xp.asarray([1e-9, 1e-9]) |
| 345 | + b = xp.asarray([1.0001e-9, 1.00001e-9]) |
| 346 | + # Difference is below default atol |
| 347 | + xp_assert_equal(isclose(a, b), xp.asarray([True, True])) |
| 348 | + # Use only rtol |
| 349 | + xp_assert_equal(isclose(a, b, atol=0), xp.asarray([False, True])) |
| 350 | + xp_assert_equal(isclose(a, b, atol=0, rtol=0), xp.asarray([False, False])) |
| 351 | + |
| 352 | + def test_bool_dtype(self, xp: ModuleType): |
| 353 | + a = xp.asarray([False, True, False]) |
| 354 | + b = xp.asarray([True, True, False]) |
| 355 | + xp_assert_equal(isclose(a, b), xp.asarray([False, True, True])) |
| 356 | + xp_assert_equal(isclose(a, b, atol=1), xp.asarray([True, True, True])) |
| 357 | + xp_assert_equal(isclose(a, b, atol=2), xp.asarray([True, True, True])) |
| 358 | + xp_assert_equal(isclose(a, b, rtol=1), xp.asarray([True, True, True])) |
| 359 | + xp_assert_equal(isclose(a, b, rtol=2), xp.asarray([True, True, True])) |
| 360 | + |
| 361 | + # Test broadcasting |
| 362 | + xp_assert_equal( |
| 363 | + isclose(a, xp.asarray(True), atol=1), xp.asarray([True, True, True]) |
| 364 | + ) |
| 365 | + xp_assert_equal( |
| 366 | + isclose(xp.asarray(True), b, atol=1), xp.asarray([True, True, True]) |
| 367 | + ) |
| 368 | + |
| 369 | + def test_none_shape(self, xp: ModuleType): |
| 370 | + a = xp.asarray([1, 5, 0]) |
| 371 | + b = xp.asarray([1, 4, 2]) |
| 372 | + b = b[a < 5] |
| 373 | + a = a[a < 5] |
| 374 | + xp_assert_equal(isclose(a, b), xp.asarray([True, False])) |
| 375 | + |
| 376 | + def test_none_shape_bool(self, xp: ModuleType): |
| 377 | + a = xp.asarray([True, True, False]) |
| 378 | + b = xp.asarray([True, False, True]) |
| 379 | + b = b[a] |
| 380 | + a = a[a] |
| 381 | + xp_assert_equal(isclose(a, b), xp.asarray([True, False])) |
| 382 | + |
| 383 | + def test_xp(self, xp: ModuleType): |
| 384 | + a = xp.asarray([0.0, 0.0]) |
| 385 | + b = xp.asarray([1e-9, 1e-4]) |
| 386 | + xp_assert_equal(isclose(a, b, xp=xp), xp.asarray([True, False])) |
| 387 | + |
| 388 | + |
255 | 389 | @pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no expand_dims")
|
256 | 390 | class TestKron:
|
257 | 391 | def test_basic(self, xp: ModuleType):
|
|
0 commit comments