Skip to content

Commit caaa471

Browse files
committed
fixup! Implemented BatchedEinsumArrayContext
avoid loop-fusion errors associated with saturations of long dimensions
1 parent e4cf24b commit caaa471

File tree

1 file changed

+6
-2
lines changed
  • arraycontext/impl/pytato/batched_einsum

1 file changed

+6
-2
lines changed

arraycontext/impl/pytato/batched_einsum/utils.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -191,10 +191,14 @@ def apply_kennedy_fusion_with_batched_einsum_extension(
191191

192192
if insn.reduction_inames():
193193
einsum, _ = fnsm.get_a_matched_einsum(
194-
t_unit, insn_match=lp_match.Id(insn.id))
194+
t_unit, insn_match=lp_match.Id(insn.id),
195+
# only consider inames with same length for fusion
196+
# => do not parametrize inames with very long loop-counts.
197+
long_dim_length=np.inf)
195198
einsum = fnsm.canonicalize_einsum(einsum)
196199
subst_map = fnsm.match_t_unit_to_einsum(
197-
t_unit, einsum, insn_match=lp_match.Id(insn.id))
200+
t_unit, einsum, insn_match=lp_match.Id(insn.id),
201+
long_dim_length=np.inf)
198202
else:
199203
# we treat any non-reduction einsum as a copy-einsum
200204
assignee = insn.assignee

0 commit comments

Comments
 (0)