Skip to content

Commit 8211a9c

Browse files
committed
Test BatchedEinsumPytatoPyOpenCLArrayContext
1 parent b22feb6 commit 8211a9c

File tree

1 file changed

+268
-0
lines changed

1 file changed

+268
-0
lines changed

Diff for: test/test_batched_einsum_actx.py

+268
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,268 @@
1+
import pytest
2+
3+
4+
try:
5+
import feinsum # noqa: F401
6+
except ModuleNotFoundError:
7+
pytest.skip(reason="BatchedEinsumActx imposes feinsum as a hard dep.",
8+
allow_module_level=True)
9+
10+
try:
11+
from loopy import get_kennedy_unweighted_fusion_candidates # noqa: F401
12+
from loopy import rename_inames_in_batch # noqa: F401
13+
except ImportError:
14+
pytest.skip(reason="BatchedEinsumActx imposes loop-fusion support in "
15+
"loopy as a hard dep.", allow_module_level=True)
16+
17+
from dataclasses import dataclass
18+
19+
import numpy as np
20+
21+
from pytools.tag import UniqueTag
22+
23+
from arraycontext import (
24+
BatchedEinsumPytatoPyOpenCLArrayContext as BaseBatchedEinsumArrayContext,
25+
PyOpenCLArrayContext, PytatoPyOpenCLArrayContext, tag_axes)
26+
from arraycontext.pytest import (
27+
_PytestEagerJaxArrayContextFactory, _PytestPyOpenCLArrayContextFactoryWithClass,
28+
_PytestPytatoJaxArrayContextFactory, _PytestPytatoPyOpenCLArrayContextFactory,
29+
_PytestSplitPytatoPyOpenCLArrayContextFactory,
30+
pytest_generate_tests_for_array_contexts)
31+
32+
33+
# {{{ axes tag types for image processing
34+
35+
class AxisTagsForTesting(UniqueTag):
36+
pass
37+
38+
39+
class ImageDimensionTag(AxisTagsForTesting):
40+
"""
41+
An abstract tag type that is tagged to an array's axis indexing along an image's
42+
axis.
43+
"""
44+
45+
46+
class XDimension(ImageDimensionTag):
47+
"""
48+
A tag that is attached to a :class:`pytato.array.Axis` that indexes along the
49+
x-dimension of an image.
50+
"""
51+
52+
53+
class YDimension(ImageDimensionTag):
54+
"""
55+
A tag that is attached to a :class:`pytato.array.Axis` that indexes along the
56+
y-dimension of an image.
57+
"""
58+
59+
60+
class ChannelDimension(ImageDimensionTag):
61+
"""
62+
A tag that is attached to a :class:`pytato.array.Axis` that indexes along the
63+
channels of an image.
64+
"""
65+
66+
# }}}
67+
68+
69+
# {{{ generic axes tags
70+
71+
@dataclass(frozen=True)
72+
class NamedAxis(AxisTagsForTesting):
73+
name: str
74+
75+
# }}}
76+
77+
78+
# {{{ array context fixture
79+
80+
class BatchedEinsumPytatoPyOpenCLArrayContext(
81+
BaseBatchedEinsumArrayContext):
82+
def __init__(self, queue, allocator=None):
83+
super().__init__(queue, allocator,
84+
fallback_to_no_fusion=False,
85+
loop_fusion_axis_tag_t=AxisTagsForTesting)
86+
87+
88+
class _PyOpenCLArrayContextForTests(PyOpenCLArrayContext):
89+
"""Like :class:`PyOpenCLArrayContext`, but applies no program transformations
90+
whatsoever. Only to be used for testing internal to :mod:`arraycontext`.
91+
"""
92+
93+
def transform_loopy_program(self, t_unit):
94+
return t_unit
95+
96+
97+
class _PytatoPyOpenCLArrayContextForTests(PytatoPyOpenCLArrayContext):
98+
"""Like :class:`PytatoPyOpenCLArrayContext`, but applies no program
99+
transformations whatsoever. Only to be used for testing internal to
100+
:mod:`arraycontext`.
101+
"""
102+
103+
def transform_loopy_program(self, t_unit):
104+
return t_unit
105+
106+
107+
class _PytatoPyOpenCLArrayContextForTestsFactory(
108+
_PytestPytatoPyOpenCLArrayContextFactory):
109+
actx_class = _PytatoPyOpenCLArrayContextForTests
110+
111+
112+
class _PyOpenCLArrayContextForTestsFactoryWithHostScalars(
113+
_PytestPyOpenCLArrayContextFactoryWithClass):
114+
force_device_scalars = True
115+
actx_class = _PyOpenCLArrayContextForTests
116+
117+
118+
class _PytestBatchedEinsumPytatoPyOpenCLArrayContextFactory(
119+
_PytestPytatoPyOpenCLArrayContextFactory):
120+
@property
121+
def actx_class(self):
122+
return BatchedEinsumPytatoPyOpenCLArrayContext
123+
124+
125+
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
126+
_PyOpenCLArrayContextForTestsFactoryWithHostScalars,
127+
_PytatoPyOpenCLArrayContextForTestsFactory,
128+
_PytestEagerJaxArrayContextFactory,
129+
_PytestPytatoJaxArrayContextFactory,
130+
_PytestSplitPytatoPyOpenCLArrayContextFactory,
131+
_PytestBatchedEinsumPytatoPyOpenCLArrayContextFactory,
132+
])
133+
134+
# }}}
135+
136+
137+
def test_simple_add(actx_factory):
138+
# Lesson 01 of Halide Tutorial
139+
actx = actx_factory()
140+
141+
rng = np.random.default_rng(0)
142+
a_np = rng.random((800, 600))
143+
b_np = rng.random((800, 600))
144+
a = actx.from_numpy(a_np)
145+
b = actx.from_numpy(b_np)
146+
147+
a = tag_axes(actx, {0: XDimension(), 1: YDimension()}, a)
148+
b = tag_axes(actx, {0: XDimension(), 1: YDimension()}, b)
149+
150+
out = actx.to_numpy(a + b)
151+
ref_out = a_np + b_np
152+
153+
np.testing.assert_allclose(out, ref_out)
154+
155+
156+
def test_brighten_image(actx_factory):
157+
# Lesson 02 of Halide Tutorial
158+
actx = actx_factory()
159+
160+
rng = np.random.default_rng(0)
161+
162+
img_np = 255*rng.random((800, 600, 3), dtype=np.float32)
163+
164+
img = actx.from_numpy(img_np)
165+
img = tag_axes(actx,
166+
{0: XDimension(), 1: YDimension(), 2: ChannelDimension()},
167+
img)
168+
169+
brightened_img = 1.5*img
170+
clamped_brightened_img = actx.np.minimum(brightened_img, np.float32(255))
171+
172+
out = actx.to_numpy(clamped_brightened_img)
173+
ref_out = np.minimum(1.5*img_np, np.float32(255))
174+
175+
np.testing.assert_allclose(out, ref_out)
176+
177+
178+
def test_simple_einsum(actx_factory):
179+
actx = actx_factory()
180+
181+
rng = np.random.default_rng()
182+
183+
a_np = rng.random((10, 4))
184+
a = actx.from_numpy(a_np)
185+
a = tag_axes(actx,
186+
{0: XDimension(), 1: YDimension()}, a)
187+
188+
out1 = actx.einsum("ij,ij->i", a, a+1)
189+
out2 = actx.einsum("ij,ij->i", 2*a, 3*a+7)
190+
191+
ref_out = (np.einsum("ij,ij->i", a_np, a_np + 1)
192+
+ np.einsum("ij,ij->i", 2*a_np, 3*a_np+7))
193+
out = actx.to_numpy(out1 + out2)
194+
195+
np.testing.assert_allclose(ref_out, out)
196+
197+
198+
def test_nested_einsum(actx_factory):
199+
actx = actx_factory()
200+
201+
rng = np.random.default_rng()
202+
203+
a_np = rng.random((10, 4))
204+
205+
# {{{ compute out
206+
207+
a = actx.from_numpy(a_np)
208+
a = tag_axes(actx,
209+
{0: XDimension(), 1: YDimension()}, a)
210+
b = a + 1
211+
212+
out1 = actx.einsum("ij,ij->i", a, b)
213+
out2 = actx.einsum("ij,ij->i", 2*a, 3*a+7)
214+
out3 = actx.einsum("ij,i->i", 3*b, 2*out1)
215+
216+
out = actx.to_numpy(out1 + out2 + out3)
217+
218+
# }}}
219+
220+
# {{{ compute ref_out
221+
222+
b_np = a_np + 1
223+
out1_np = np.einsum("ij,ij->i", a_np, a_np+1)
224+
out2_np = np.einsum("ij,ij->i", 2*a_np, 3*a_np+7)
225+
out3_np = np.einsum("ij,i->i", 3*b_np, 2*out1_np)
226+
ref_out = out1_np + out2_np + out3_np
227+
228+
# }}}
229+
230+
np.testing.assert_allclose(ref_out, out)
231+
232+
233+
def test_dg_3d_divergence(actx_factory):
234+
actx = actx_factory()
235+
rng = np.random.default_rng(42)
236+
n_el = 1000
237+
n_dof = 35
238+
239+
ax_np, ay_np, az_np = rng.random((3, n_el, n_dof))
240+
jac_np = rng.random((3, 3, n_el))
241+
mat_np = rng.random((3, n_dof, n_dof))
242+
243+
ax, ay, az = (actx.from_numpy(ax_np),
244+
actx.from_numpy(ay_np),
245+
actx.from_numpy(az_np))
246+
jac = actx.from_numpy(jac_np)
247+
jac = tag_axes(actx, {0: NamedAxis("x"),
248+
1: NamedAxis("r"),
249+
2: NamedAxis("e")}, jac)
250+
mat = actx.from_numpy(mat_np)
251+
mat = tag_axes(actx, {0: NamedAxis("r"),
252+
1: NamedAxis("i"),
253+
2: NamedAxis("j")}, mat)
254+
255+
out = 2*actx.einsum(
256+
"xre,rij,xej->ei",
257+
jac, mat, actx.np.stack([3*actx.np.sin(ax) + 4*actx.np.cos(ax),
258+
12*actx.np.exp(ay) + 5*actx.np.cos(ay),
259+
8*az]))
260+
ref_out = 2*np.einsum(
261+
"xre,rij,xej->ei",
262+
jac_np, mat_np, np.stack([3*np.sin(ax_np) + 4*np.cos(ax_np),
263+
12*np.exp(ay_np) + 5*np.cos(ay_np),
264+
8*az_np]))
265+
266+
np.testing.assert_allclose(ref_out, actx.to_numpy(out))
267+
268+
# vim: fdm=marker

0 commit comments

Comments
 (0)