Skip to content

Commit 8dabef0

Browse files
committed
add test for jaxcontainer
1 parent 54f842d commit 8dabef0

File tree

1 file changed

+38
-23
lines changed

1 file changed

+38
-23
lines changed

modules/jax/tests/test_jax.py

+38-23
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,43 @@
11
import pytest
22
from modules.jax.testcontainers.jax_cuda import JAXContainer
33

4-
def test_jax_container():
5-
with JAXContainer() as jax_container:
6-
jax_container.connect()
7-
8-
# Test running a simple JAX computation
9-
result = jax_container.run_jax_command("import jax; print(jax.numpy.add(1, 1))")
10-
assert "2" in result.output.decode()
4+
@pytest.fixture(scope="module")
5+
def jax_container():
6+
with JAXContainer() as container:
7+
container.connect()
8+
yield container
119

12-
def test_jax_container_gpu_support():
13-
with JAXContainer() as jax_container:
14-
jax_container.connect()
15-
16-
# Test GPU availability
17-
result = jax_container.run_jax_command(
18-
"import jax; print(jax.devices())"
19-
)
20-
assert "gpu" in result.output.decode().lower()
10+
def test_jax_container_basic_computation(jax_container):
11+
result = jax_container.run_jax_command("import jax; print(jax.numpy.add(1, 1))")
12+
assert "2" in result.output.decode(), "Basic JAX computation failed"
2113

22-
def test_jax_container_jupyter():
23-
with JAXContainer() as jax_container:
24-
jax_container.connect()
25-
26-
jupyter_url = jax_container.get_jupyter_url()
27-
assert jupyter_url.startswith("http://")
28-
assert ":8888" in jupyter_url
14+
def test_jax_container_version(jax_container):
15+
result = jax_container.run_jax_command("import jax; print(jax.__version__)")
16+
assert result.exit_code == 0, "Failed to get JAX version"
17+
assert result.output.decode().strip(), "JAX version is empty"
18+
19+
def test_jax_container_gpu_support(jax_container):
20+
result = jax_container.run_jax_command(
21+
"import jax; devices = jax.devices(); "
22+
"print(any(dev.platform == 'gpu' for dev in devices))"
23+
)
24+
assert "True" in result.output.decode(), "No GPU device found"
25+
26+
def test_jax_container_matrix_multiplication(jax_container):
27+
command = """
28+
import jax
29+
import jax.numpy as jnp
30+
x = jnp.array([[1, 2], [3, 4]])
31+
y = jnp.array([[5, 6], [7, 8]])
32+
result = jnp.dot(x, y)
33+
print(result)
34+
"""
35+
result = jax_container.run_jax_command(command)
36+
assert "[[19 22]\n [43 50]]" in result.output.decode(), "Matrix multiplication failed"
37+
38+
def test_jax_container_custom_image():
39+
custom_image = "nvcr.io/nvidia/jax:23.09-py3"
40+
with JAXContainer(image=custom_image) as container:
41+
container.connect()
42+
result = container.run_jax_command("import jax; print(jax.__version__)")
43+
assert result.exit_code == 0, f"Failed to run JAX with custom image {custom_image}"

0 commit comments

Comments
 (0)