
Commit a24f534

Fix bug in ScanSaveMem with broadcasted initial value
1 parent c822a8e commit a24f534
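
For context, the failure mode: when a scan's initial state is a value broadcast across its taps (e.g. a scalar filled into a length-2 initial buffer), the save-mem rewrite used to rebuild the storage buffer from the unbroadcast value, slicing it to the wrong shape. A minimal sketch of the triggering graph, distilled from the regression test added in this commit:

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    val = pt.tensor("val", shape=())   # scalar initial value
    init = pt.full((2,), val)          # broadcast to fill taps (-2, -1)
    ys, _ = pytensor.scan(
        fn=lambda ytm2, ytm1: ytm1 + ytm2,
        outputs_info=[{"initial": init, "taps": (-2, -1)}],
        n_steps=100,
    )
    # Only the last 50 steps are needed, so the save-mem rewrite shrinks
    # the scan buffer; before this fix it could slice `val` instead.
    fn = pytensor.function([val], ys[-50:])
    assert fn(np.zeros((), dtype=val.dtype)).shape == (50,)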

File tree: 2 files changed, +86 −70 lines

pytensor/scan/rewriting.py (+67 −63)
@@ -58,7 +58,11 @@
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import Dot, dot, maximum, minimum
-from pytensor.tensor.rewriting.basic import constant_folding, local_useless_switch
+from pytensor.tensor.rewriting.basic import (
+    broadcasted_by,
+    constant_folding,
+    local_useless_switch,
+)
 from pytensor.tensor.rewriting.elemwise import local_upcast_elemwise_constant_inputs
 from pytensor.tensor.rewriting.math import local_abs_merge, local_mul_switch_sink
 from pytensor.tensor.shape import shape
@@ -1182,6 +1186,44 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node):
     return subtensor_merge_replacements


+def _is_default_scan_buffer(x: TensorVariable) -> bool:
+    node = x.owner
+
+    if node is None:
+        return False
+
+    op = node.op
+    if not (
+        isinstance(op, IncSubtensor)
+        and op.set_instead_of_inc
+        and op.idx_list == [slice(None, ps.int64)]
+    ):
+        return False
+
+    x, y, *_ = node.inputs
+    if not (x.owner is not None and isinstance(x.owner.op, AllocEmpty)):
+        return False
+
+    # The value may have been broadcast to fill in the initial taps.
+    # If the user specified outputs as:
+    #   x = scalar(); init = alloc(x, 2);
+    #   outputs_info=[init, taps=(-2, -1)]
+    # Scan will generate an initial buffer that looks like
+    #   alloc_empty(2 + nsteps)[:2].set(alloc(x, 2))
+    # PyTensor will then rewrite it as:
+    #   alloc_empty(2 + nsteps)[:2].set(x)
+    # When the initial value (x) is being broadcast by the set_subtensor
+    # we can't recreate a newly sized buffer working with x alone.
+    # We want to check that:
+    #   1. alloc_empty(2 + nsteps)[:2].broadcastable == x.broadcastable
+    # But due to laziness we use the slightly more conservative check:
+    #   2. alloc_empty(2 + nsteps).broadcastable == x.broadcastable
+    if broadcasted_by(y, x):
+        return False
+
+    return True
+
+
 def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: bool):
     r"""Graph optimizer that reduces scan memory consumption.
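The guard at the heart of `_is_default_scan_buffer` is the `broadcasted_by(y, x)` check, using the helper imported from `pytensor.tensor.rewriting.basic` in the first hunk. A small illustrative sketch, assuming the helper compares broadcastable patterns as its name suggests (it answers "would the first argument be broadcast by the second in an elemwise-style operation?"):

    import pytensor.tensor as pt
    from pytensor.tensor.rewriting.basic import broadcasted_by

    buf = pt.zeros((2,))                    # stands in for alloc_empty(2 + nsteps)
    scalar_val = pt.tensor("x", shape=())
    vector_val = pt.full((2,), scalar_val)

    # A scalar written into the buffer is broadcast by set_subtensor,
    # so the buffer cannot be rebuilt from the scalar alone:
    assert broadcasted_by(scalar_val, buf)      # not a "default" buffer
    # A value already matching the buffer's pattern is safe to resize:
    assert not broadcasted_by(vector_val, buf)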
@@ -1520,51 +1562,28 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:

         # 3.2 check orphane outputs to see if we can eliminate any
         required, not_required = scan_can_remove_outs(node.op, orphane_outs)
-        # 3.3. compose replace pairs for those nodes that need not
-        # to store everything in memory ( or ar orphane and required
-        # by the inner function .. )
+
+        # 3.3. compose replace pairs for those nodes that need not store everything in memory
+        # (or ar orphan but required by the inner function)
         replaced_outs = []
         offset = 1 + op_info.n_seqs + op_info.n_mit_mot
-        for idx, _val in enumerate(store_steps[op_info.n_mit_mot :]):
+        for idx, val in enumerate(store_steps[op_info.n_mit_mot :]):
             i = idx + op_info.n_mit_mot
-            if not (isinstance(_val, int) and _val <= 0 and i not in required):
-                if idx + op_info.n_mit_mot in required:
-                    val = 1
-                else:
-                    val = _val
+            if not (isinstance(val, int) and val <= 0 and i not in required):
+                required_orphan = idx + op_info.n_mit_mot in required
                 # If the memory for this output has been pre-allocated
                 # before going into the scan op (by an alloc node)
                 if idx < op_info.n_mit_sot + op_info.n_sit_sot:
-                    # In case the input is still an alloc node, we
-                    # actually have two options:
-                    #   a) the input is a set_subtensor, in that case we
-                    #      can replace the initial tensor by a slice,
-                    #   b) it is not, and we simply take a slice of it.
-                    # TODO: commit change below with Razvan
-                    if (
-                        nw_inputs[offset + idx].owner
-                        and isinstance(nw_inputs[offset + idx].owner.op, IncSubtensor)
-                        and nw_inputs[offset + idx].owner.op.set_instead_of_inc
-                        and isinstance(
-                            nw_inputs[offset + idx].owner.op.idx_list[0], slice
-                        )
-                        # Don't try to create a smart Alloc, if set_subtensor is broadcasting the fill value
-                        # As it happens in set_subtensor(empty(2)[:], 0)
-                        and not (
-                            nw_inputs[offset + idx].ndim
-                            > nw_inputs[offset + idx].owner.inputs[1].ndim
-                        )
-                    ):
-                        _nw_input = nw_inputs[offset + idx].owner.inputs[1]
-                        cval = pt.as_tensor_variable(val)
-                        initl = pt.as_tensor_variable(init_l[i])
-                        tmp_idx = pt.switch(cval < initl, cval + initl, cval - initl)
-                        nw_input = expand_empty(_nw_input, tmp_idx)
+                    nw_input = nw_inputs[offset + idx]
+
+                    # Recreate default buffers with new size
+                    if _is_default_scan_buffer(nw_input):
+                        extra_size = 1 if required_orphan else val - init_l[i]
+                        nw_input = expand_empty(nw_input.owner.inputs[1], extra_size)
+                    # Otherwise, just trim with a slice
                     else:
-                        tmp = pt.as_tensor_variable(val)
-                        initl = pt.as_tensor_variable(init_l[i])
-                        tmp = maximum(tmp, initl)
-                        nw_input = nw_inputs[offset + idx][:tmp]
+                        stop = init_l[i] if required_orphan else val
+                        nw_input = nw_input[:stop]

                     nw_inputs[offset + idx] = nw_input
                     replaced_outs.append(op_info.n_mit_mot + idx)
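To make the new sizing arithmetic concrete: with taps (-2, -1) the initial state occupies init_l[i] == 2 rows. If only the last 50 values of the output are requested, val == 50, so extra_size == 50 - 2 == 48 and expand_empty rebuilds the buffer from the 2-row initial value to 2 + 48 == 50 rows. The truncated-nsteps case handled in section 3.4 below instead keeps the 2 initial taps plus all 50 remaining steps, giving 52 rows. These are exactly the two buffer sizes asserted by the new regression test.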
@@ -1588,7 +1607,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
                         + op_info.n_shared_outs
                     )
                     if nw_inputs[pos] == node.inputs[0]:
-                        nw_inputs[pos] = val
+                        nw_inputs[pos] = 1 if required_orphan else val
                     odx = op_info.n_mit_mot + idx
                     replaced_outs.append(odx)
                     old_outputs += [
@@ -1600,37 +1619,22 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
                     ],
                 )
             ]
-        # 3.4. Recompute inputs for everything else based on the new
-        # number of steps
+        # 3.4. Recompute inputs for everything else based on the new number of steps
        if global_nsteps is not None:
             for idx, val in enumerate(store_steps[op_info.n_mit_mot :]):
                 if val == 0:
                     # val == 0 means that we want to keep all intermediate
                     # results for that state, including the initial values.
                     if idx < op_info.n_mit_sot + op_info.n_sit_sot:
                         in_idx = offset + idx
-                        # Number of steps in the initial state
-                        initl = init_l[op_info.n_mit_mot + idx]
-
-                        # If the initial buffer has the form
-                        # inc_subtensor(zeros(...)[...], _nw_input)
-                        # we want to make the zeros tensor as small as
-                        # possible (nw_steps + initl), and call
-                        # inc_subtensor on that instead.
-                        # Otherwise, simply take 0:(nw_steps+initl).
-                        if (
-                            nw_inputs[in_idx].owner
-                            and isinstance(nw_inputs[in_idx].owner.op, IncSubtensor)
-                            and isinstance(
-                                nw_inputs[in_idx].owner.op.idx_list[0], slice
-                            )
-                        ):
-                            _nw_input = nw_inputs[in_idx].owner.inputs[1]
-                            nw_input = expand_empty(_nw_input, nw_steps)
-                            nw_inputs[in_idx] = nw_input
+                        nw_input = nw_inputs[in_idx]
+                        if _is_default_scan_buffer(nw_input):
+                            nw_input = expand_empty(nw_input.owner.inputs[1], nw_steps)
                         else:
-                            # FIXME: This is never used
-                            nw_input = nw_inputs[in_idx][: (initl + nw_steps)]
+                            # Number of steps in the initial state
+                            init_l_pt = pt.as_tensor(init_l[op_info.n_mit_mot + idx])
+                            nw_input = nw_input[: (init_l_pt + nw_steps)]
+                        nw_inputs[in_idx] = nw_input

                 elif (
                     idx < op_info.n_mit_sot + op_info.n_sit_sot + op_info.n_nit_sot

tests/scan/test_rewriting.py (+19 −7)
@@ -1634,21 +1634,33 @@ def test_while_scan_taps_and_map(self):
         assert stored_ys_steps == 2
         assert stored_zs_steps == 1

-    def test_vector_zeros_init(self):
+    @pytest.mark.parametrize("val_ndim", (0, 1))
+    @pytest.mark.parametrize("keep_beginning", (False, True))
+    def test_broadcasted_init(self, keep_beginning, val_ndim):
+        # Regression test when the original value is a broadcasted alloc
+        # The scan save mem rewrite used to wrongly slice on the unbroadcasted value
+        val_shape = (1,) * val_ndim
+        val = pt.tensor("val", shape=val_shape)
+        val_test = np.zeros(val_shape, dtype=val.dtype)
+
+        init = pt.full((2,), val)
         ys, _ = pytensor.scan(
-            fn=lambda ytm2, ytm1: ytm1 + ytm2,
-            outputs_info=[{"initial": pt.zeros(2), "taps": range(-2, 0)}],
+            fn=lambda *args: pt.add(*args),
+            outputs_info=[{"initial": init, "taps": (-2, -1)}],
             n_steps=100,
         )

-        fn = pytensor.function([], ys[-50:], mode=self.mode)
-        assert tuple(fn().shape) == (50,)
+        out = ys[:-50] if keep_beginning else ys[-50:]
+        fn = pytensor.function([val], out, mode=self.mode)
+        assert fn(val_test).shape == (50,)

         # Check that rewrite worked
         [scan_node] = (n for n in fn.maker.fgraph.apply_nodes if isinstance(n.op, Scan))
         _, ys_trace = scan_node.inputs
-        debug_fn = pytensor.function([], ys_trace.shape[0], accept_inplace=True)
-        assert debug_fn() == 50
+        buffer_size_fn = pytensor.function(
+            [val], ys_trace.shape[0], accept_inplace=True
+        )
+        assert buffer_size_fn(val_test) == (52 if keep_beginning else 50)


 def test_inner_replace_dot():
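The parametrized regression test can be run on its own with something like:

    pytest tests/scan/test_rewriting.py -k test_broadcasted_init

which exercises all four combinations of val_ndim (scalar vs. broadcastable vector initial value) and keep_beginning (trimming the end vs. the start of the trace).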
