|
| 1 | +""" |
| 2 | +Test that array_api_compat has no "hard" dependencies. |
| 3 | +
|
| 4 | +Libraries like NumPy should only be imported if a numpy array is passed to |
| 5 | +array_namespace or if array_api_compat.numpy is explicitly imported. |
| 6 | +
|
| 7 | +We have to test this in a subprocess because these libraries have already been |
| 8 | +imported from the other tests. |
| 9 | +""" |
| 10 | + |
| 11 | +import sys |
| 12 | +import subprocess |
| 13 | + |
| 14 | +from ._helpers import import_ |
| 15 | + |
| 16 | +import pytest |
| 17 | + |
| 18 | +class Array: |
| 19 | + # Dummy array namespace that doesn't depend on any array library |
| 20 | + def __array_namespace__(self, api_version=None): |
| 21 | + class Namespace: |
| 22 | + pass |
| 23 | + return Namespace() |
| 24 | + |
| 25 | +def _test_dependency(mod): |
| 26 | + assert mod not in sys.modules |
| 27 | + |
| 28 | + # Run various functions that shouldn't depend on mod and check that they |
| 29 | + # don't import it. |
| 30 | + |
| 31 | + import array_api_compat |
| 32 | + assert mod not in sys.modules |
| 33 | + |
| 34 | + a = Array() |
| 35 | + |
| 36 | + # array-api-strict is an example of an array API library that isn't |
| 37 | + # wrapped by array-api-compat. |
| 38 | + if "strict" not in mod: |
| 39 | + is_mod_array = getattr(array_api_compat, f"is_{mod.split('.')[0]}_array") |
| 40 | + assert not is_mod_array(a) |
| 41 | + assert mod not in sys.modules |
| 42 | + |
| 43 | + is_array_api_obj = getattr(array_api_compat, "is_array_api_obj") |
| 44 | + assert is_array_api_obj(a) |
| 45 | + assert mod not in sys.modules |
| 46 | + |
| 47 | + array_namespace = getattr(array_api_compat, "array_namespace") |
| 48 | + array_namespace(Array()) |
| 49 | + assert mod not in sys.modules |
| 50 | + |
| 51 | +# TODO: Test that wrapper for library X doesn't depend on wrappers for library |
| 52 | +# Y (except most array libraries actually do themselves depend on numpy). |
| 53 | + |
| 54 | +@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", |
| 55 | + "jax.numpy", "array_api_strict"]) |
| 56 | +def test_numpy_dependency(library): |
| 57 | + # This unfortunately won't go through any of the pytest machinery. We |
| 58 | + # reraise the exception as an AssertionError so that pytest will show it |
| 59 | + # in a semi-reasonable way |
| 60 | + |
| 61 | + # Import (in this process) to make sure 'library' is actually installed and |
| 62 | + # so that cupy can be skipped. |
| 63 | + import_(library) |
| 64 | + |
| 65 | + try: |
| 66 | + subprocess.run([sys.executable, '-c', f'''\ |
| 67 | +from tests.test_no_dependencies import _test_dependency |
| 68 | +
|
| 69 | +_test_dependency({library!r})'''], check=True, capture_output=True, encoding='utf-8') |
| 70 | + except subprocess.CalledProcessError as e: |
| 71 | + print(e.stdout, end='') |
| 72 | + raise AssertionError(e.stderr) |
0 commit comments