2
2
import sys
3
3
import warnings
4
4
5
- import jax
6
5
import numpy as np
7
6
import pytest
8
- import torch
9
7
10
8
import array_api_compat
11
9
from array_api_compat import array_namespace
@@ -76,6 +74,7 @@ def test_array_namespace(library, api_version, use_compat):
76
74
subprocess .run ([sys .executable , "-c" , code ], check = True )
77
75
78
76
def test_jax_zero_gradient ():
77
+ jax = import_ ("jax" )
79
78
jx = jax .numpy .arange (4 )
80
79
jax_zero = jax .vmap (jax .grad (jax .numpy .float32 , allow_int = True ))(jx )
81
80
assert array_namespace (jax_zero ) is array_namespace (jx )
@@ -89,11 +88,13 @@ def test_array_namespace_errors():
89
88
pytest .raises (TypeError , lambda : array_namespace (x , (x , x )))
90
89
91
90
def test_array_namespace_errors_torch ():
91
+ torch = import_ ("torch" )
92
92
y = torch .asarray ([1 , 2 ])
93
93
x = np .asarray ([1 , 2 ])
94
94
pytest .raises (TypeError , lambda : array_namespace (x , y ))
95
95
96
96
def test_api_version_torch ():
97
+ torch = import_ ("torch" )
97
98
x = torch .asarray ([1 , 2 ])
98
99
torch_ = import_ ("torch" , wrapper = True )
99
100
assert array_namespace (x , api_version = "2023.12" ) == torch_
@@ -118,6 +119,7 @@ def test_get_namespace():
118
119
assert array_api_compat .get_namespace is array_namespace
119
120
120
121
def test_python_scalars ():
122
+ torch = import_ ("torch" )
121
123
a = torch .asarray ([1 , 2 ])
122
124
xp = import_ ("torch" , wrapper = True )
123
125
0 commit comments