File tree 1 file changed +6
-2
lines changed
arraycontext/impl/pytato/batched_einsum
1 file changed +6
-2
lines changed Original file line number Diff line number Diff line change @@ -191,10 +191,14 @@ def apply_kennedy_fusion_with_batched_einsum_extension(
191
191
192
192
if insn .reduction_inames ():
193
193
einsum , _ = fnsm .get_a_matched_einsum (
194
- t_unit , insn_match = lp_match .Id (insn .id ))
194
+ t_unit , insn_match = lp_match .Id (insn .id ),
195
+ # only consider inames with same length for fusion
196
+ # => do not parametrize inames with very long loop-counts.
197
+ long_dim_length = np .inf )
195
198
einsum = fnsm .canonicalize_einsum (einsum )
196
199
subst_map = fnsm .match_t_unit_to_einsum (
197
- t_unit , einsum , insn_match = lp_match .Id (insn .id ))
200
+ t_unit , einsum , insn_match = lp_match .Id (insn .id ),
201
+ long_dim_length = np .inf )
198
202
else :
199
203
# we treat any non-reduction einsum as a copy-einsum
200
204
assignee = insn .assignee
You can’t perform that action at this time.
0 commit comments