1
1
import pytest
2
2
from modules .jax .testcontainers .jax_cuda import JAXContainer
3
3
4
- def test_jax_container ():
5
- with JAXContainer () as jax_container :
6
- jax_container .connect ()
7
-
8
- # Test running a simple JAX computation
9
- result = jax_container .run_jax_command ("import jax; print(jax.numpy.add(1, 1))" )
10
- assert "2" in result .output .decode ()
4
+ @pytest .fixture (scope = "module" )
5
+ def jax_container ():
6
+ with JAXContainer () as container :
7
+ container .connect ()
8
+ yield container
11
9
12
- def test_jax_container_gpu_support ():
13
- with JAXContainer () as jax_container :
14
- jax_container .connect ()
15
-
16
- # Test GPU availability
17
- result = jax_container .run_jax_command (
18
- "import jax; print(jax.devices())"
19
- )
20
- assert "gpu" in result .output .decode ().lower ()
10
+ def test_jax_container_basic_computation (jax_container ):
11
+ result = jax_container .run_jax_command ("import jax; print(jax.numpy.add(1, 1))" )
12
+ assert "2" in result .output .decode (), "Basic JAX computation failed"
21
13
22
- def test_jax_container_jupyter ():
23
- with JAXContainer () as jax_container :
24
- jax_container .connect ()
25
-
26
- jupyter_url = jax_container .get_jupyter_url ()
27
- assert jupyter_url .startswith ("http://" )
28
- assert ":8888" in jupyter_url
14
+ def test_jax_container_version (jax_container ):
15
+ result = jax_container .run_jax_command ("import jax; print(jax.__version__)" )
16
+ assert result .exit_code == 0 , "Failed to get JAX version"
17
+ assert result .output .decode ().strip (), "JAX version is empty"
18
+
19
+ def test_jax_container_gpu_support (jax_container ):
20
+ result = jax_container .run_jax_command (
21
+ "import jax; devices = jax.devices(); "
22
+ "print(any(dev.platform == 'gpu' for dev in devices))"
23
+ )
24
+ assert "True" in result .output .decode (), "No GPU device found"
25
+
26
+ def test_jax_container_matrix_multiplication (jax_container ):
27
+ command = """
28
+ import jax
29
+ import jax.numpy as jnp
30
+ x = jnp.array([[1, 2], [3, 4]])
31
+ y = jnp.array([[5, 6], [7, 8]])
32
+ result = jnp.dot(x, y)
33
+ print(result)
34
+ """
35
+ result = jax_container .run_jax_command (command )
36
+ assert "[[19 22]\n [43 50]]" in result .output .decode (), "Matrix multiplication failed"
37
+
38
+ def test_jax_container_custom_image ():
39
+ custom_image = "nvcr.io/nvidia/jax:23.09-py3"
40
+ with JAXContainer (image = custom_image ) as container :
41
+ container .connect ()
42
+ result = container .run_jax_command ("import jax; print(jax.__version__)" )
43
+ assert result .exit_code == 0 , f"Failed to run JAX with custom image { custom_image } "
0 commit comments