Skip to content

Commit 6f3c347

Browse files
committed
Implemented BatchedEinsumArrayContext
1 parent 9bba8e4 commit 6f3c347

File tree

3 files changed

+753
-0
lines changed

3 files changed

+753
-0
lines changed

arraycontext/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
from .impl.jax import EagerJAXArrayContext
5454
from .impl.pyopencl import PyOpenCLArrayContext
5555
from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext
56+
from .impl.pytato.batched_einsum import BatchedEinsumPytatoPyOpenCLArrayContext
5657
from .impl.pytato.split_actx import SplitPytatoPyOpenCLArrayContext
5758
from .loopy import make_loopy_program
5859
# deprecated, remove in 2022.
@@ -100,6 +101,7 @@
100101

101102
"PyOpenCLArrayContext", "PytatoPyOpenCLArrayContext",
102103
"SplitPytatoPyOpenCLArrayContext",
104+
"BatchedEinsumPytatoPyOpenCLArrayContext",
103105

104106
"PytatoJAXArrayContext",
105107
"EagerJAXArrayContext",
Lines changed: 379 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,379 @@
1+
"""
2+
.. autoclass:: BatchedEinsumPytatoPyOpenCLArrayContext
3+
"""
4+
5+
__copyright__ = """
6+
Copyright (C) 2023 Kaushik Kulkarni
7+
Copyright (C) 2022 Andreas Kloeckner
8+
Copyright (C) 2022 Matthias Diener
9+
Copyright (C) 2022 Matt Smith
10+
"""
11+
12+
__license__ = """
13+
Permission is hereby granted, free of charge, to any person obtaining a copy
14+
of this software and associated documentation files (the "Software"), to deal
15+
in the Software without restriction, including without limitation the rights
16+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
17+
copies of the Software, and to permit persons to whom the Software is
18+
furnished to do so, subject to the following conditions:
19+
20+
The above copyright notice and this permission notice shall be included in
21+
all copies or substantial portions of the Software.
22+
23+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
24+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
25+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
26+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
27+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
28+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
29+
THE SOFTWARE.
30+
"""
31+
32+
33+
import logging
34+
import sys
35+
from typing import TYPE_CHECKING, Any, Callable, Optional, Type
36+
from warnings import warn
37+
38+
import numpy as np
39+
40+
import loopy as lp
41+
from pytools import ProcessLogger
42+
from pytools.tag import Tag
43+
44+
from arraycontext.impl.pytato import PytatoPyOpenCLArrayContext
45+
46+
47+
logger = logging.getLogger(__name__)
48+
49+
50+
if TYPE_CHECKING or getattr(sys, "_BUILDING_SPHINX_DOCS", False):
51+
import pyopencl as cl
52+
import pytato
53+
54+
55+
class BatchedEinsumPytatoPyOpenCLArrayContext(PytatoPyOpenCLArrayContext):
56+
r"""
57+
.. attribute:: loop_fusion_axis_tag_t
58+
59+
A subtype of :class:`pytato.tag.Tag` that are attached to the
60+
:class:`~pytato.array.Array`\ 's axes in an expression graph. Loops that
61+
iterate over axes tagged with instances of same such tag types will form the
62+
candidate loops for Kennedy's unweighted Loop Fusion algorithm.
63+
64+
.. attribute:: fallback_to_no_fusion
65+
66+
If *True*, during the compilation of an array expression graph for which
67+
loop fusion fails (see note) transformation routines from
68+
:class:`arraycontext.SplitPytatoPyOpenCLArrayContext` are invoked.
69+
70+
.. attribute:: feinsum_db
71+
72+
An instance of :class:`str` corresponding to the database of tuned batched
73+
einsums. If *None*, then a static transformation strategy is applied to the
74+
batched einsums kernels.
75+
76+
.. attribute:: log_loopy_statistics
77+
78+
If *True*, statistics of compiled :class:`loopy.TranslationUnit` will be
79+
logged. If enable, we log the FLOPS and global memory access footprint for
80+
each of the programs. If *False*, nothing is done.
81+
82+
.. note::
83+
84+
The conditions under which we fallback (or raise) are:
85+
86+
#. There exists an array that is to be materialized but at least one of its
87+
axes is not tagged with tags of :attr:`loop_fusion_axis_tag_t`.
88+
"""
89+
def __init__(
90+
self,
91+
queue: "cl.CommandQueue", allocator=None,
92+
*,
93+
loop_fusion_axis_tag_t: Type[Tag],
94+
fallback_to_no_fusion: bool = True,
95+
assume_all_indirection_maps_as_non_negative: bool = False,
96+
compile_trace_callback: Optional[Callable[[Any, str, Any], None]] = None,
97+
feinsum_db: Optional[str] = None,
98+
log_loopy_statistics: bool = False,
99+
fused_loop_name_prefix_getter: Optional[Callable[[Tag], str]] = None,
100+
) -> None:
101+
super().__init__(queue,
102+
allocator,
103+
compile_trace_callback=compile_trace_callback)
104+
105+
self.loop_fusion_axis_tag_t = loop_fusion_axis_tag_t
106+
self.fallback_to_no_fusion = fallback_to_no_fusion
107+
self.feinsum_db = feinsum_db
108+
self.assume_all_indirection_maps_as_non_negative = (
109+
assume_all_indirection_maps_as_non_negative)
110+
self.log_loopy_statistics = log_loopy_statistics
111+
if fused_loop_name_prefix_getter:
112+
self.fused_loop_name_prefix_getter = fused_loop_name_prefix_getter
113+
else:
114+
self.fused_loop_name_prefix_getter = lambda tag_t: "ifused"
115+
116+
def transform_dag(self,
117+
dag: "pytato.DictOfNamedArrays") -> "pytato.DictOfNamedArrays":
118+
import pytato as pt
119+
120+
from .utils import (
121+
_make_passthrough_arg, get_indirection_maps,
122+
get_inputs_and_outputs_of_reduction_nodes)
123+
from arraycontext.impl.pytato.split_actx.utils import (
124+
get_inputs_and_outputs_of_einsum)
125+
126+
# Step 1. Collapse equivalent nodes in DAG.
127+
# -----------------------------------------
128+
# type-ignore-reason: mypy is right pytato provides imprecise types.
129+
dag = pt.transform.deduplicate_data_wrappers(dag) # type: ignore[assignment]
130+
131+
# Step 2. Materialize einsum/reduction outputs.
132+
# ---------------------------------------------
133+
_, einsum_outputs = get_inputs_and_outputs_of_einsum(dag)
134+
_, reduction_outputs = get_inputs_and_outputs_of_reduction_nodes(dag)
135+
136+
def materialize_all_einsums_or_reduces(expr):
137+
if (expr in einsum_outputs
138+
or expr in reduction_outputs):
139+
return expr.tagged(pt.tags.ImplStored())
140+
else:
141+
return expr
142+
143+
# type-ignore-reason: mypy is right pytato provides imprecise types.
144+
dag = pt.transform.map_and_copy(dag, # type: ignore[assignment]
145+
materialize_all_einsums_or_reduces)
146+
147+
# Step 3. Materialize with MPMS
148+
# -----------------------------
149+
dag = pt.transform.materialize_with_mpms(dag)
150+
151+
# Step 4. Mark all indirection maps as non-negative
152+
# -------------------------------------------------
153+
if self.assume_all_indirection_maps_as_non_negative:
154+
indirection_maps = get_indirection_maps(dag)
155+
156+
def tag_indices_as_non_negative(ary):
157+
if ary in indirection_maps:
158+
return ary.tagged(pt.tags.AssumeNonNegative())
159+
else:
160+
return ary
161+
162+
# type-ignore-reason: mypy is right pytato provides imprecise types.
163+
dag = pt.transform.map_and_copy(dag, # type: ignore[assignment]
164+
tag_indices_as_non_negative)
165+
166+
# Step 5. Get rid of broadcasts in einsum expressions (helps feinsum)
167+
# -------------------------------------------------------------------
168+
dag = pt.rewrite_einsums_with_no_broadcasts(dag)
169+
170+
# Step 6. Infer axis tags
171+
# -----------------------
172+
# type-ignore-reason: mypy is right pytato provides imprecise types.
173+
dag = pt.unify_axes_tags(dag) # type: ignore[assignment]
174+
175+
# Step 7. Make all pt.einsum/pt.reduction inputs as substitutions
176+
# ---------------------------------------------------------------
177+
def implement_einsum_reduction_inputs_as_substs(expr):
178+
from immutables import Map
179+
180+
from pytato.target.loopy import ImplSubstitution
181+
if isinstance(expr, pt.Einsum):
182+
# make the arguments passthrough to make use of already stored
183+
# values.
184+
# pylint and 'attrs' have poor compatibility
185+
# pylint: disable=too-many-function-args,redundant-keyword-arg
186+
# pylint: disable=unexpected-keyword-arg
187+
return pt.Einsum(
188+
expr.access_descriptors,
189+
tuple(_make_passthrough_arg(arg, ImplSubstitution())
190+
for arg in expr.args),
191+
expr.redn_axis_to_redn_descr,
192+
expr.index_to_access_descr,
193+
tags=expr.tags,
194+
axes=expr.axes,
195+
)
196+
elif isinstance(expr, pt.IndexLambda) and expr.var_to_reduction_descr:
197+
# make the arguments passthrough to make use of already stored
198+
# values.
199+
# pylint: disable=too-many-function-args,redundant-keyword-arg
200+
# pylint: disable=unexpected-keyword-arg
201+
return pt.IndexLambda(
202+
expr.expr,
203+
expr.shape,
204+
expr.dtype,
205+
Map({name: _make_passthrough_arg(bnd, ImplSubstitution())
206+
for name, bnd in expr.bindings.items()}),
207+
expr.var_to_reduction_descr,
208+
tags=expr.tags,
209+
axes=expr.axes,
210+
)
211+
else:
212+
return expr
213+
214+
# type-ignore-reason: mypy is right pytato provides imprecise types.
215+
dag = pt.transform.map_and_copy(dag, # type: ignore[assignment]
216+
implement_einsum_reduction_inputs_as_substs)
217+
218+
return dag
219+
220+
def transform_loopy_program(self,
221+
t_unit: lp.TranslationUnit) -> lp.TranslationUnit:
222+
knl_name = t_unit.default_entrypoint.name
223+
224+
logger.info(f"[{self.__class__}.transform_loopy_program]:"
225+
f" Transforming kernel '{knl_name}' with"
226+
f" {len(t_unit.default_entrypoint.instructions)} statements.")
227+
228+
# Step 0. Fallback if cannot t_unit cannot be transformed
229+
# -------------------------------------------------------
230+
for iname in t_unit.default_entrypoint.all_inames():
231+
if not t_unit.default_entrypoint.iname_tags_of_type(
232+
iname, self.loop_fusion_axis_tag_t):
233+
if self.fallback_to_no_fusion:
234+
warn(f"[{knl_name}]: Falling back to a slower transformation"
235+
" strategy as some loops are uninferred which mesh entity"
236+
" they belong to.",
237+
stacklevel=2)
238+
from arraycontext.impl.pytato.split_actx import (
239+
SplitPytatoPyOpenCLArrayContext)
240+
241+
# type-ignore-reason: mypy is right, we are passing incorrect
242+
# types, but knowing the implementation of
243+
# SplitPytatoPyOpenCLArrayContext this should be fine.
244+
return SplitPytatoPyOpenCLArrayContext.transform_loopy_program(
245+
self, t_unit) # type: ignore[arg-type]
246+
else:
247+
raise RuntimeError(f"Iname '{iname}' is not tagged with tags"
248+
f" of type '{self.loop_fusion_axis_tag_t}'"
249+
" => Not allowed since Kennedy's Loop fusion"
250+
" cannot be applied.")
251+
252+
# Step 0.5. Make offsets as 0. (FIXME: move this to loopy knl invocation)
253+
# -----------------------------------------------------------------------
254+
knl = t_unit.default_entrypoint
255+
knl = knl.copy(args=[arg.copy(offset=0) for arg in knl.args])
256+
t_unit = t_unit.with_kernel(knl)
257+
del knl
258+
259+
# Step 1. Fuse loops indexing over the same tag
260+
# ---------------------------------------------
261+
with ProcessLogger(logger, f"[{knl_name}] Loop Fusion"):
262+
from .utils import apply_kennedy_fusion_with_batched_einsum_extension
263+
t_unit = apply_kennedy_fusion_with_batched_einsum_extension(
264+
t_unit, self.loop_fusion_axis_tag_t,
265+
self.fused_loop_name_prefix_getter)
266+
267+
# Step 2. Combine the domains of individual loop nests into individual
268+
# BasicSets
269+
# --------------------------------------------------------------------
270+
from .utils import combine_domains_of_perfect_loop_nests
271+
t_unit = combine_domains_of_perfect_loop_nests(t_unit)
272+
273+
# Step 3. Remove dead temporaries
274+
# -------------------------------
275+
from .utils import remove_dead_temporaries
276+
t_unit = remove_dead_temporaries(t_unit)
277+
278+
# Step 4. Contract arrays
279+
# -----------------------
280+
with ProcessLogger(logger, f"[{knl_name}] Array Contraction"):
281+
from .utils import contract_arrays
282+
t_unit = contract_arrays(t_unit)
283+
284+
# Step 5. Collect statistics
285+
# --------------------------
286+
287+
# {{{ compute stats
288+
289+
if self.log_loopy_statistics:
290+
291+
with ProcessLogger(logger, f"[{knl_name}] Count kernel metrics"):
292+
from loopy.kernel.array import ArrayBase
293+
from pytools import product
294+
knl = t_unit.default_entrypoint
295+
knl = knl.copy(
296+
silenced_warnings=(knl.silenced_warnings
297+
+ ["insn_count_subgroups_upper_bound",
298+
"summing_if_branches_ops"]))
299+
300+
t_unit = t_unit.with_kernel(knl)
301+
del knl
302+
303+
op_map = lp.get_op_map(t_unit, subgroup_size=32)
304+
305+
c64_ops = {op_type: (op_map.filter_by(dtype=[np.complex64],
306+
name=op_type,
307+
kernel_name=knl_name)
308+
.eval_and_sum({}))
309+
for op_type in ["add", "mul", "div"]}
310+
c128_ops = {op_type: (op_map.filter_by(dtype=[np.complex128],
311+
name=op_type,
312+
kernel_name=knl_name)
313+
.eval_and_sum({}))
314+
for op_type in ["add", "mul", "div"]}
315+
f32_ops = ((op_map.filter_by(dtype=[np.float32],
316+
kernel_name=knl_name)
317+
.eval_and_sum({}))
318+
+ (2 * c64_ops["add"]
319+
+ 6 * c64_ops["mul"]
320+
+ (6 + 3 + 2) * c64_ops["div"]))
321+
f64_ops = ((op_map.filter_by(dtype=[np.float64],
322+
kernel_name="_pt_kernel")
323+
.eval_and_sum({}))
324+
+ (2 * c128_ops["add"]
325+
+ 6 * c128_ops["mul"]
326+
+ (6 + 3 + 2) * c128_ops["div"]))
327+
328+
# {{{ footprint gathering
329+
330+
nfootprint_bytes = 0
331+
332+
for ary in knl.args:
333+
if (isinstance(ary, ArrayBase)
334+
and ary.address_space == lp.AddressSpace.GLOBAL):
335+
nfootprint_bytes += (product(ary.shape)
336+
* ary.dtype.itemsize)
337+
338+
for ary in knl.temporary_variables.values():
339+
if ary.address_space == lp.AddressSpace.GLOBAL:
340+
# global temps would be written once and read once
341+
nfootprint_bytes += (2 * product(ary.shape)
342+
* ary.dtype.itemsize)
343+
344+
# }}}
345+
346+
if f32_ops:
347+
logger.info(f"Single-prec. GFlOps: {f32_ops * 1e-9}")
348+
if f64_ops:
349+
logger.info(f"Double-prec. GFlOps: {f64_ops * 1e-9}")
350+
logger.info(f"Footprint GBs: {nfootprint_bytes * 1e-9}")
351+
352+
# }}}
353+
354+
# Step 6. Draw kernel boundaries between batched einsum kernels
355+
# -------------------------------------------------------------
356+
from arraycontext.impl.pytato.split_actx.utils import (
357+
add_gbarrier_between_disjoint_loop_nests)
358+
359+
t_unit = add_gbarrier_between_disjoint_loop_nests(t_unit)
360+
361+
# Step 7. Macro-kernel optimizations
362+
# ----------------------------------
363+
if self.feinsum_db:
364+
raise NotImplementedError
365+
else:
366+
from arraycontext.impl.pytato.split_actx.utils import (
367+
split_iteration_domain_across_work_items)
368+
t_unit = split_iteration_domain_across_work_items(t_unit,
369+
self.queue.device)
370+
371+
# Step 8. Alias global temporaries with disjoint live intervals
372+
# -------------------------------------------------------------
373+
from arraycontext.impl.pytato.split_actx.utils import (
374+
alias_global_temporaries)
375+
t_unit = alias_global_temporaries(t_unit)
376+
377+
return t_unit
378+
379+
# vim: fdm=marker

0 commit comments

Comments
 (0)