from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import Dot, dot, maximum, minimum
-from pytensor.tensor.rewriting.basic import constant_folding, local_useless_switch
+from pytensor.tensor.rewriting.basic import (
+    broadcasted_by,
+    constant_folding,
+    local_useless_switch,
+)
from pytensor.tensor.rewriting.elemwise import local_upcast_elemwise_constant_inputs
from pytensor.tensor.rewriting.math import local_abs_merge, local_mul_switch_sink
from pytensor.tensor.shape import shape
@@ -1182,6 +1186,44 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node):
    return subtensor_merge_replacements


+def _is_default_scan_buffer(x: TensorVariable) -> bool:
+    node = x.owner
+
+    if node is None:
+        return False
+
+    op = node.op
+    if not (
+        isinstance(op, IncSubtensor)
+        and op.set_instead_of_inc
+        and op.idx_list == [slice(None, ps.int64)]
+    ):
+        return False
+
+    x, y, *_ = node.inputs
+    if not (x.owner is not None and isinstance(x.owner.op, AllocEmpty)):
+        return False
+
+    # The value may have been broadcast to fill in the initial taps.
+    # If the user specified outputs as:
+    #   x = scalar(); init = alloc(x, 2);
+    #   outputs_info=[init, taps=(-2, -1)]
+    # Scan will generate an initial buffer that looks like
+    #   alloc_empty(2 + nsteps)[:2].set(alloc(x, 2))
+    # PyTensor will then rewrite it as:
+    #   alloc_empty(2 + nsteps)[:2].set(x)
+    # When the initial value (x) is being broadcast by the set_subtensor,
+    # we can't recreate a newly sized buffer working with x alone.
+    # We want to check that:
+    #   1. alloc_empty(2 + nsteps)[:2].broadcastable == x.broadcastable
+    # But due to laziness we use the slightly more conservative check:
+    #   2. alloc_empty(2 + nsteps).broadcastable == x.broadcastable
+    if broadcasted_by(y, x):
+        return False
+
+    return True
+
+
def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: bool):
    r"""Graph optimizer that reduces scan memory consumption.

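For orientation, here is a rough sketch (not part of the patch) of the graph shapes the new helper accepts and rejects. All variable names and shapes below are illustrative assumptions; only `AllocEmpty`, `set_subtensor`, and `broadcasted_by` are the actual pieces the helper inspects.

    # Illustrative sketch of the buffers _is_default_scan_buffer classifies.
    import pytensor.tensor as pt
    from pytensor.tensor.basic import AllocEmpty

    ntaps = pt.lscalar("ntaps")    # number of initial taps (symbolic int64 stop)
    nsteps = pt.lscalar("nsteps")  # number of scan steps
    init = pt.vector("init")       # initial state, shape (ntaps,)

    # Default Scan buffer: alloc_empty(ntaps + nsteps)[:ntaps].set(init),
    # i.e. an IncSubtensor(set_instead_of_inc=True) over an AllocEmpty.
    empty = AllocEmpty("float64")(ntaps + nsteps)
    buf = pt.set_subtensor(empty[:ntaps], init)   # accepted shape of graph

    # Rejected: the set value is broadcast to fill the taps, so the buffer
    # cannot be rebuilt from the value alone (broadcasted_by(y, x) is True).
    bad = pt.set_subtensor(empty[:ntaps], pt.scalar("x"))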
@@ -1520,51 +1562,28 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:

        # 3.2 check orphane outputs to see if we can eliminate any
        required, not_required = scan_can_remove_outs(node.op, orphane_outs)
-        # 3.3. compose replace pairs for those nodes that need not
-        # to store everything in memory ( or ar orphane and required
-        # by the inner function .. )
+
+        # 3.3. compose replace pairs for those nodes that need not store everything in memory
+        # (or are orphans but required by the inner function)
        replaced_outs = []
        offset = 1 + op_info.n_seqs + op_info.n_mit_mot
-        for idx, _val in enumerate(store_steps[op_info.n_mit_mot :]):
+        for idx, val in enumerate(store_steps[op_info.n_mit_mot :]):
            i = idx + op_info.n_mit_mot
-            if not (isinstance(_val, int) and _val <= 0 and i not in required):
-                if idx + op_info.n_mit_mot in required:
-                    val = 1
-                else:
-                    val = _val
+            if not (isinstance(val, int) and val <= 0 and i not in required):
+                required_orphan = idx + op_info.n_mit_mot in required
                # If the memory for this output has been pre-allocated
                # before going into the scan op (by an alloc node)
                if idx < op_info.n_mit_sot + op_info.n_sit_sot:
-                    # In case the input is still an alloc node, we
-                    # actually have two options:
-                    #   a) the input is a set_subtensor, in that case we
-                    #      can replace the initial tensor by a slice,
-                    #   b) it is not, and we simply take a slice of it.
-                    # TODO: commit change below with Razvan
-                    if (
-                        nw_inputs[offset + idx].owner
-                        and isinstance(nw_inputs[offset + idx].owner.op, IncSubtensor)
-                        and nw_inputs[offset + idx].owner.op.set_instead_of_inc
-                        and isinstance(
-                            nw_inputs[offset + idx].owner.op.idx_list[0], slice
-                        )
-                        # Don't try to create a smart Alloc, if set_subtensor is broadcasting the fill value
-                        # As it happens in set_subtensor(empty(2)[:], 0)
-                        and not (
-                            nw_inputs[offset + idx].ndim
-                            > nw_inputs[offset + idx].owner.inputs[1].ndim
-                        )
-                    ):
-                        _nw_input = nw_inputs[offset + idx].owner.inputs[1]
-                        cval = pt.as_tensor_variable(val)
-                        initl = pt.as_tensor_variable(init_l[i])
-                        tmp_idx = pt.switch(cval < initl, cval + initl, cval - initl)
-                        nw_input = expand_empty(_nw_input, tmp_idx)
+                    nw_input = nw_inputs[offset + idx]
+
+                    # Recreate default buffers with new size
+                    if _is_default_scan_buffer(nw_input):
+                        extra_size = 1 if required_orphan else val - init_l[i]
+                        nw_input = expand_empty(nw_input.owner.inputs[1], extra_size)
+                    # Otherwise, just trim with a slice
                    else:
-                        tmp = pt.as_tensor_variable(val)
-                        initl = pt.as_tensor_variable(init_l[i])
-                        tmp = maximum(tmp, initl)
-                        nw_input = nw_inputs[offset + idx][:tmp]
+                        stop = init_l[i] if required_orphan else val
+                        nw_input = nw_input[:stop]

                    nw_inputs[offset + idx] = nw_input
                    replaced_outs.append(op_info.n_mit_mot + idx)
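To make the new size arithmetic concrete, a small worked sketch under assumed numbers. `expand_empty` is the `pytensor.scan.utils` helper the rewrite calls (it grows a variable's leading dimension by writing it into a larger AllocEmpty); the concrete values are illustrative.

    # Worked sizes for the 3.3 branch, with assumed numbers.
    import pytensor.tensor as pt
    from pytensor.scan.utils import expand_empty

    init_taps = 2   # init_l[i]: steps provided by the initial state
    val = 5         # store_steps entry: keep only the last 5 steps

    x0 = pt.vector("x0")  # value inside a default buffer, shape (init_taps,)

    # Default buffer: rebuild an empty buffer of exactly val rows
    # (init_taps + extra_size), with x0 written into the first rows.
    extra_size = val - init_taps          # 3 extra rows after the taps
    nw_input = expand_empty(x0, extra_size)

    # Non-default buffer: trim the existing one with a slice instead.
    buf = pt.matrix("buf")
    trimmed = buf[:val]   # or buf[:init_taps] when only required as an orphan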
@@ -1588,7 +1607,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
                        + op_info.n_shared_outs
                    )
                    if nw_inputs[pos] == node.inputs[0]:
-                        nw_inputs[pos] = val
+                        nw_inputs[pos] = 1 if required_orphan else val
                    odx = op_info.n_mit_mot + idx
                    replaced_outs.append(odx)
                    old_outputs += [
@@ -1600,37 +1619,22 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
                            ],
                        )
                    ]
-        # 3.4. Recompute inputs for everything else based on the new
-        # number of steps
+        # 3.4. Recompute inputs for everything else based on the new number of steps
        if global_nsteps is not None:
            for idx, val in enumerate(store_steps[op_info.n_mit_mot :]):
                if val == 0:
                    # val == 0 means that we want to keep all intermediate
                    # results for that state, including the initial values.
                    if idx < op_info.n_mit_sot + op_info.n_sit_sot:
                        in_idx = offset + idx
-                        # Number of steps in the initial state
-                        initl = init_l[op_info.n_mit_mot + idx]
-
-                        # If the initial buffer has the form
-                        # inc_subtensor(zeros(...)[...], _nw_input)
-                        # we want to make the zeros tensor as small as
-                        # possible (nw_steps + initl), and call
-                        # inc_subtensor on that instead.
-                        # Otherwise, simply take 0:(nw_steps+initl).
-                        if (
-                            nw_inputs[in_idx].owner
-                            and isinstance(nw_inputs[in_idx].owner.op, IncSubtensor)
-                            and isinstance(
-                                nw_inputs[in_idx].owner.op.idx_list[0], slice
-                            )
-                        ):
-                            _nw_input = nw_inputs[in_idx].owner.inputs[1]
-                            nw_input = expand_empty(_nw_input, nw_steps)
-                            nw_inputs[in_idx] = nw_input
+                        nw_input = nw_inputs[in_idx]
+                        if _is_default_scan_buffer(nw_input):
+                            nw_input = expand_empty(nw_input.owner.inputs[1], nw_steps)
                        else:
-                            # FIXME: This is never used
-                            nw_input = nw_inputs[in_idx][: (initl + nw_steps)]
+                            # Number of steps in the initial state
+                            init_l_pt = pt.as_tensor(init_l[op_info.n_mit_mot + idx])
+                            nw_input = nw_input[: (init_l_pt + nw_steps)]
+                        nw_inputs[in_idx] = nw_input

                elif (
                    idx < op_info.n_mit_sot + op_info.n_sit_sot + op_info.n_nit_sot
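For the `val == 0` (keep everything) branch above, the same two paths apply; a hedged sketch with assumed names:

    # val == 0: keep all steps, so the buffer needs init_taps + nw_steps rows.
    import pytensor.tensor as pt
    from pytensor.scan.utils import expand_empty

    nw_steps = pt.lscalar("nw_steps")
    x0 = pt.vector("x0")   # value inside a default buffer, shape (init_taps,)

    # Default buffer: rebuilt from its value with room for all nw_steps.
    rebuilt = expand_empty(x0, nw_steps)

    # Any other buffer: sliced down to exactly init_l + nw_steps rows.
    buf = pt.matrix("buf")
    init_l_pt = pt.as_tensor(2)   # e.g. two initial taps
    sliced = buf[: (init_l_pt + nw_steps)]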