Skip to content

Commit e4cf24b

Browse files
committed
fixup! Test BatchedEinsumPytatoPyOpenCLArrayContext
Adds a failing test for the dimension mismatch error.
1 parent 8211a9c commit e4cf24b

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

test/test_batched_einsum_actx.py

+29
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import numpy as np
2020

21+
from pytools.obj_array import make_obj_array
2122
from pytools.tag import UniqueTag
2223

2324
from arraycontext import (
@@ -265,4 +266,32 @@ def test_dg_3d_divergence(actx_factory):
265266

266267
np.testing.assert_allclose(ref_out, actx.to_numpy(out))
267268

269+
270+
def test_multiple_large_sized_outputs(actx_factory):
271+
actx = actx_factory()
272+
rng = np.random.default_rng(0)
273+
n1 = 1_000_000
274+
n2 = 2_000_000
275+
276+
x1_np = rng.random((n1, 1))
277+
x2_np = rng.random((n2, 1))
278+
279+
x1 = actx.from_numpy(x1_np)
280+
x2 = actx.from_numpy(x2_np)
281+
282+
x1 = tag_axes(actx, {0: NamedAxis("e"),
283+
1: NamedAxis("i")},
284+
x1)
285+
x2 = tag_axes(actx, {0: NamedAxis("e"),
286+
1: NamedAxis("i")},
287+
x2)
288+
289+
out = make_obj_array([actx.einsum("ij->i", 3 * x1),
290+
actx.einsum("ij->i", 4 * x2)])
291+
ref_out = make_obj_array([np.einsum("ij->i", 3 * x1_np),
292+
np.einsum("ij->i", 4 * x2_np)])
293+
294+
np.testing.assert_allclose(ref_out[0], actx.to_numpy(out)[0])
295+
np.testing.assert_allclose(ref_out[1], actx.to_numpy(out)[1])
296+
268297
# vim: fdm=marker

0 commit comments

Comments
 (0)