Skip to content

Commit a882502

Browse files
authored
Merge pull request #336 from lucascolley/runtime-xp
feat: `ARRAY_API_TESTS_MODULE` for runtime-defined xp
1 parent f7a74a6 commit a882502

File tree

2 files changed

+27
-13
lines changed

2 files changed

+27
-13
lines changed

Diff for: README.md

+6
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ You need to specify the array library to test. It can be specified via the
3636
$ export ARRAY_API_TESTS_MODULE=array_api_strict
3737
```
3838

39+
To specify a runtime-defined module, define `xp` using the `exec('...')` syntax:
40+
41+
```bash
42+
$ export ARRAY_API_TESTS_MODULE=exec('import quantity_array, numpy; xp = quantity_array.quantity_namespace(numpy)')
43+
```
44+
3945
Alternately, import/define the `xp` variable in `array_api_tests/__init__.py`.
4046

4147
### Specifying the API version

Diff for: array_api_tests/__init__.py

+21-13
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,27 @@
1313
# You can comment the following out and instead import the specific array module
1414
# you want to test, e.g. `import array_api_strict as xp`.
1515
if "ARRAY_API_TESTS_MODULE" in os.environ:
16-
xp_name = os.environ["ARRAY_API_TESTS_MODULE"]
17-
_module, _sub = xp_name, None
18-
if "." in xp_name:
19-
_module, _sub = xp_name.split(".", 1)
20-
xp = import_module(_module)
21-
if _sub:
22-
try:
23-
xp = getattr(xp, _sub)
24-
except AttributeError:
25-
# _sub may be a submodule that needs to be imported. WE can't
26-
# do this in every case because some array modules are not
27-
# submodules that can be imported (like mxnet.nd).
28-
xp = import_module(xp_name)
16+
env_var = os.environ["ARRAY_API_TESTS_MODULE"]
17+
if env_var.startswith("exec('") and env_var.endswith("')"):
18+
script = env_var[6:][:-2]
19+
namespace = {}
20+
exec(script, namespace)
21+
xp = namespace["xp"]
22+
xp_name = xp.__name__
23+
else:
24+
xp_name = env_var
25+
_module, _sub = xp_name, None
26+
if "." in xp_name:
27+
_module, _sub = xp_name.split(".", 1)
28+
xp = import_module(_module)
29+
if _sub:
30+
try:
31+
xp = getattr(xp, _sub)
32+
except AttributeError:
33+
# _sub may be a submodule that needs to be imported. We can't
34+
# do this in every case because some array modules are not
35+
# submodules that can be imported (like mxnet.nd).
36+
xp = import_module(xp_name)
2937
else:
3038
raise RuntimeError(
3139
"No array module specified - either edit __init__.py or set the "

0 commit comments

Comments
 (0)