
Commit e39fda3 (parent: a377c22)

Make blockwise perform method node dependent

2 files changed: +64 -26 lines

pytensor/tensor/blockwise.py

Lines changed: 29 additions & 26 deletions
@@ -1,5 +1,4 @@
 from collections.abc import Sequence
-from copy import copy
 from typing import Any, cast
 
 import numpy as np
@@ -79,7 +78,6 @@ def __init__(
         self.name = name
         self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature)
         self.gufunc_spec = gufunc_spec
-        self._gufunc = None
         if destroy_map is not None:
             self.destroy_map = destroy_map
         if self.destroy_map != core_op.destroy_map:
@@ -91,11 +89,6 @@ def __init__(
 
         super().__init__(**kwargs)
 
-    def __getstate__(self):
-        d = copy(self.__dict__)
-        d["_gufunc"] = None
-        return d
-
     def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply:
         core_input_types = []
         for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)):
@@ -296,32 +289,40 @@ def L_op(self, inputs, outs, ograds):
 
         return rval
 
-    def _create_gufunc(self, node):
+    def _create_node_gufunc(self, node) -> None:
+        """Define (or retrieve) the node gufunc used in `perform`.
+
+        If the Blockwise or core_op have a `gufunc_spec`, the relevant numpy or scipy gufunc is used directly.
+        Otherwise, we default to `np.vectorize` of the core_op `perform` method for a dummy node.
+
+        The gufunc is stored in the tag of the node.
+        """
         gufunc_spec = self.gufunc_spec or getattr(self.core_op, "gufunc_spec", None)
 
         if gufunc_spec is not None:
-            self._gufunc = import_func_from_string(gufunc_spec[0])
-            if self._gufunc:
-                return self._gufunc
-            else:
+            gufunc = import_func_from_string(gufunc_spec[0])
+            if gufunc is None:
                 raise ValueError(f"Could not import gufunc {gufunc_spec[0]} for {self}")
 
-        n_outs = len(self.outputs_sig)
-        core_node = self._create_dummy_core_node(node.inputs)
+        else:
+            # Wrap core_op perform method in numpy vectorize
+            n_outs = len(self.outputs_sig)
+            core_node = self._create_dummy_core_node(node.inputs)
 
-        def core_func(*inner_inputs):
-            inner_outputs = [[None] for _ in range(n_outs)]
+            def core_func(*inner_inputs):
+                inner_outputs = [[None] for _ in range(n_outs)]
 
-            inner_inputs = [np.asarray(inp) for inp in inner_inputs]
-            self.core_op.perform(core_node, inner_inputs, inner_outputs)
+                inner_inputs = [np.asarray(inp) for inp in inner_inputs]
+                self.core_op.perform(core_node, inner_inputs, inner_outputs)
 
-            if len(inner_outputs) == 1:
-                return inner_outputs[0][0]
-            else:
-                return tuple(r[0] for r in inner_outputs)
+                if len(inner_outputs) == 1:
+                    return inner_outputs[0][0]
+                else:
+                    return tuple(r[0] for r in inner_outputs)
+
+            gufunc = np.vectorize(core_func, signature=self.signature)
 
-        self._gufunc = np.vectorize(core_func, signature=self.signature)
-        return self._gufunc
+        node.tag.gufunc = gufunc
 
     def _check_runtime_broadcast(self, node, inputs):
         batch_ndim = self.batch_ndim(node)
@@ -340,10 +341,12 @@ def _check_runtime_broadcast(self, node, inputs):
                 )
 
     def perform(self, node, inputs, output_storage):
-        gufunc = self._gufunc
+        gufunc = getattr(node.tag, "gufunc", None)
 
         if gufunc is None:
-            gufunc = self._create_gufunc(node)
+            # Cache it once per node
+            self._create_node_gufunc(node)
+            gufunc = node.tag.gufunc
 
         self._check_runtime_broadcast(node, inputs)
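
For context on the fallback path in `_create_node_gufunc` above: when no `gufunc_spec` is available, the core `perform` method is wrapped with `np.vectorize` using the gufunc signature, and the resulting callable is cached on the node's `tag` so it is only built once per node. A minimal sketch of that mechanic, using illustrative stand-in `Node`/`Tag` classes rather than the real PyTensor types:

import numpy as np


class Tag:
    """Stand-in for the attribute bag PyTensor hangs off each Apply node."""


class Node:
    """Illustrative stand-in for an Apply node with inputs and a tag."""

    def __init__(self, inputs):
        self.inputs = inputs
        self.tag = Tag()


def core_func(x):
    # Stands in for a call to core_op.perform on one core-shaped input.
    return x + 1


def perform(node, x):
    # Mirror of the perform() change above: look the gufunc up on node.tag,
    # build it lazily with np.vectorize on the first call, and cache it there.
    gufunc = getattr(node.tag, "gufunc", None)
    if gufunc is None:
        node.tag.gufunc = gufunc = np.vectorize(core_func, signature="()->()")
    return gufunc(x)


node = Node([np.arange(3.0)])
print(perform(node, node.inputs[0]))  # [1. 2. 3.]
first = node.tag.gufunc
perform(node, node.inputs[0])
print(node.tag.gufunc is first)  # True: built once, then reused for this node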

tests/tensor/test_blockwise.py

Lines changed: 35 additions & 0 deletions
@@ -28,6 +28,41 @@
 from pytensor.tensor.utils import _parse_gufunc_signature
 
 
+def test_perform_method_per_node():
+    """Confirm that Blockwise uses one perform method per node.
+
+    This is important if the perform method requires node information (such as dtypes)
+    """
+
+    class NodeDependentPerformOp(Op):
+        def make_node(self, x):
+            return Apply(self, [x], [x.type()])
+
+        def perform(self, node, inputs, outputs):
+            [x] = inputs
+            if node.inputs[0].type.dtype.startswith("float"):
+                y = x + 1
+            else:
+                y = x - 1
+            outputs[0][0] = y
+
+    blockwise_op = Blockwise(core_op=NodeDependentPerformOp(), signature="()->()")
+    x = tensor("x", shape=(3,), dtype="float32")
+    y = tensor("y", shape=(3,), dtype="int32")
+
+    out_x = blockwise_op(x)
+    out_y = blockwise_op(y)
+    fn = pytensor.function([x, y], [out_x, out_y])
+    [op1, op2] = [node.op for node in fn.maker.fgraph.apply_nodes]
+    # Confirm both nodes have the same Op
+    assert op1 is blockwise_op
+    assert op1 is op2
+
+    res_out_x, res_out_y = fn(np.zeros(3, dtype="float32"), np.zeros(3, dtype="int32"))
+    np.testing.assert_array_equal(res_out_x, np.ones(3, dtype="float32"))
+    np.testing.assert_array_equal(res_out_y, -np.ones(3, dtype="int32"))
+
+
 def test_vectorize_blockwise():
     mat = tensor(shape=(None, None))
     tns = tensor(shape=(None, None, None))
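
The hazard this test guards against is easiest to see outside PyTensor: with an Op-level cache (the removed `self._gufunc`), the gufunc built for the first node would be silently reused for every later node, even when `perform` depends on node details such as dtypes. A hedged, PyTensor-free sketch contrasting the two caching strategies:

import numpy as np

# Shared, Op-level cache: the pattern this commit moves away from.
_op_level_cache = {}


def build_gufunc(dtype):
    # Node-dependent behavior, echoing NodeDependentPerformOp above:
    # floats get +1, everything else gets -1.
    offset = 1 if np.dtype(dtype).kind == "f" else -1
    return np.vectorize(lambda x: x + offset, signature="()->()")


def perform_with_op_cache(x):
    # Wrong: the first dtype seen fixes the behavior for all later calls.
    if "gufunc" not in _op_level_cache:
        _op_level_cache["gufunc"] = build_gufunc(x.dtype)
    return _op_level_cache["gufunc"](x)


def perform_with_node_cache(x, node_cache):
    # Right: each "node" carries its own cache, so dtype-specific behavior sticks.
    if "gufunc" not in node_cache:
        node_cache["gufunc"] = build_gufunc(x.dtype)
    return node_cache["gufunc"](x)


xf = np.zeros(3, dtype="float32")
xi = np.zeros(3, dtype="int32")

print(perform_with_op_cache(xf))  # [1. 1. 1.]
print(perform_with_op_cache(xi))  # [1 1 1]   <- wrong: reused the float gufunc

print(perform_with_node_cache(xf, {}))  # [1. 1. 1.]
print(perform_with_node_cache(xi, {}))  # [-1 -1 -1] correct, cache is per node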
