Skip to content

Commit bf3af90

Browse files
committed
Simplify the Elemwise perform method and issue an informative warning when the number of operands is too large.
This also clears a hard-to-debug error that occurred when the perform method attempted to fall back to the C implementation.
1 parent 5909d93 commit bf3af90

File tree

1 file changed

+69
-84
lines changed

1 file changed

+69
-84
lines changed

pytensor/tensor/elemwise.py

+69-84
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from collections.abc import Sequence
23
from copy import copy
34
from textwrap import dedent
@@ -19,9 +20,9 @@
1920
from pytensor.misc.frozendict import frozendict
2021
from pytensor.printing import Printer, pprint
2122
from pytensor.scalar import get_scalar_type
23+
from pytensor.scalar.basic import Composite, transfer_type, upcast
2224
from pytensor.scalar.basic import bool as scalar_bool
2325
from pytensor.scalar.basic import identity as scalar_identity
24-
from pytensor.scalar.basic import transfer_type, upcast
2526
from pytensor.tensor import elemwise_cgen as cgen
2627
from pytensor.tensor import get_vector_length
2728
from pytensor.tensor.basic import _get_vector_length, as_tensor_variable
@@ -364,6 +365,7 @@ def __init__(
364365
self.name = name
365366
self.scalar_op = scalar_op
366367
self.inplace_pattern = inplace_pattern
368+
self.ufunc = None
367369
self.destroy_map = {o: [i] for o, i in self.inplace_pattern.items()}
368370

369371
if nfunc_spec is None:
@@ -375,14 +377,12 @@ def __init__(
375377
def __getstate__(self):
376378
d = copy(self.__dict__)
377379
d.pop("ufunc")
378-
d.pop("nfunc")
379-
d.pop("__epydoc_asRoutine", None)
380380
return d
381381

382382
def __setstate__(self, d):
383+
d.pop("nfunc", None) # This used to be stored in the Op, not anymore
383384
super().__setstate__(d)
384385
self.ufunc = None
385-
self.nfunc = None
386386
self.inplace_pattern = frozendict(self.inplace_pattern)
387387

388388
def get_output_info(self, *inputs):
@@ -623,31 +623,47 @@ def transform(r):
623623

624624
return ret
625625

626-
def prepare_node(self, node, storage_map, compute_map, impl):
627-
# Postpone the ufunc building to the last minutes due to:
628-
# - NumPy ufunc support only up to 32 operands (inputs and outputs)
629-
# But our c code support more.
630-
# - nfunc is reused for scipy and scipy is optional
631-
if (len(node.inputs) + len(node.outputs)) > 32 and impl == "py":
632-
impl = "c"
633-
634-
if getattr(self, "nfunc_spec", None) and impl != "c":
635-
self.nfunc = import_func_from_string(self.nfunc_spec[0])
636-
626+
def _create_node_ufunc(self, node) -> None:
637627
if (
638-
(len(node.inputs) + len(node.outputs)) <= 32
639-
and (self.nfunc is None or self.scalar_op.nin != len(node.inputs))
640-
and self.ufunc is None
641-
and impl == "py"
628+
self.nfunc_spec is not None
629+
# Some scalar Ops like `Add` allow for a variable number of inputs,
630+
# whereas the numpy counterpart does not.
631+
and len(node.inputs) == self.nfunc_spec[1]
642632
):
633+
ufunc = import_func_from_string(self.nfunc_spec[0])
634+
if ufunc is None:
635+
raise ValueError(
636+
f"Could not import ufunc {self.nfunc_spec[0]} for {self}"
637+
)
638+
639+
elif self.ufunc is not None:
640+
# Cached before
641+
ufunc = self.ufunc
642+
643+
else:
644+
if (len(node.inputs) + len(node.outputs)) > 32:
645+
if isinstance(self.scalar_op, Composite):
646+
warnings.warn(
647+
"Trying to create a Python Composite Elemwise function with more than 32 operands.\n"
648+
"This operation should not have been introduced if the C-backend is not properly setup. "
649+
'Make sure it is, or disable it by setting pytensor.config.cxx = "" (empty string).\n'
650+
"Alternatively, consider using an optional backend like NUMBA or JAX, by setting "
651+
'`pytensor.config.mode = "NUMBA" (or "JAX").'
652+
)
653+
else:
654+
warnings.warn(
655+
f"Trying to create a Python Elemwise function for the scalar Op {self.scalar_op} "
656+
f"with more than 32 operands. This will likely fail."
657+
)
658+
643659
ufunc = np.frompyfunc(
644660
self.scalar_op.impl, len(node.inputs), self.scalar_op.nout
645661
)
646-
if self.scalar_op.nin > 0:
647-
# We can reuse it for many nodes
662+
if self.scalar_op.nin > 0: # Default in base class is -1
663+
# Op has constant signature, so we can reuse ufunc for many nodes. Cache it.
648664
self.ufunc = ufunc
649-
else:
650-
node.tag.ufunc = ufunc
665+
666+
node.tag.ufunc = ufunc
651667

652668
# Numpy ufuncs will sometimes perform operations in
653669
# float16, in particular when the input is int8.
@@ -660,15 +676,23 @@ def prepare_node(self, node, storage_map, compute_map, impl):
660676

661677
# NumPy 1.10.1 raise an error when giving the signature
662678
# when the input is complex. So add it only when inputs is int.
663-
out_dtype = node.outputs[0].dtype
679+
ufunc_kwargs = {}
664680
if (
665-
out_dtype in float_dtypes
666-
and isinstance(self.nfunc, np.ufunc)
681+
isinstance(ufunc, np.ufunc)
682+
# TODO: Why check for the dtype of the first input only?
667683
and node.inputs[0].dtype in discrete_dtypes
684+
and len(node.outputs) == 1
685+
and node.outputs[0].dtype in float_dtypes
668686
):
669-
char = np.sctype2char(out_dtype)
670-
sig = char * node.nin + "->" + char * node.nout
671-
node.tag.sig = sig
687+
char = np.sctype2char(node.outputs[0].dtype)
688+
ufunc_kwargs["sig"] = char * node.nin + "->" + char * node.nout
689+
690+
node.tag.ufunc_kwargs = ufunc_kwargs
691+
692+
def prepare_node(self, node, storage_map, compute_map, impl):
693+
if impl == "py":
694+
self._create_node_ufunc(node)
695+
672696
node.tag.fake_node = Apply(
673697
self.scalar_op,
674698
[
@@ -684,71 +708,32 @@ def prepare_node(self, node, storage_map, compute_map, impl):
684708
self.scalar_op.prepare_node(node.tag.fake_node, None, None, impl)
685709

686710
def perform(self, node, inputs, output_storage):
687-
if (len(node.inputs) + len(node.outputs)) > 32:
688-
# Some versions of NumPy will segfault, other will raise a
689-
# ValueError, if the number of operands in an ufunc is more than 32.
690-
# In that case, the C version should be used, or Elemwise fusion
691-
# should be disabled.
692-
# FIXME: This no longer calls the C implementation!
693-
super().perform(node, inputs, output_storage)
711+
ufunc = getattr(node.tag, "ufunc", None)
712+
if ufunc is None:
713+
self._create_node_ufunc(node)
714+
ufunc = node.tag.ufunc
694715

695716
self._check_runtime_broadcast(node, inputs)
696717

697-
ufunc_args = inputs
698-
ufunc_kwargs = {}
699-
# We supported in the past calling manually op.perform.
700-
# To keep that support we need to sometimes call self.prepare_node
701-
if self.nfunc is None and self.ufunc is None:
702-
self.prepare_node(node, None, None, "py")
703-
if self.nfunc and len(inputs) == self.nfunc_spec[1]:
704-
ufunc = self.nfunc
705-
nout = self.nfunc_spec[2]
706-
if hasattr(node.tag, "sig"):
707-
ufunc_kwargs["sig"] = node.tag.sig
708-
# Unfortunately, the else case does not allow us to
709-
# directly feed the destination arguments to the nfunc
710-
# since it sometimes requires resizing. Doing this
711-
# optimization is probably not worth the effort, since we
712-
# should normally run the C version of the Op.
713-
else:
714-
# the second calling form is used because in certain versions of
715-
# numpy the first (faster) version leads to segfaults
716-
if self.ufunc:
717-
ufunc = self.ufunc
718-
elif not hasattr(node.tag, "ufunc"):
719-
# It happen that make_thunk isn't called, like in
720-
# get_underlying_scalar_constant_value
721-
self.prepare_node(node, None, None, "py")
722-
# prepare_node will add ufunc to self or the tag
723-
# depending if we can reuse it or not. So we need to
724-
# test both again.
725-
if self.ufunc:
726-
ufunc = self.ufunc
727-
else:
728-
ufunc = node.tag.ufunc
729-
else:
730-
ufunc = node.tag.ufunc
731-
732-
nout = ufunc.nout
733-
734-
variables = ufunc(*ufunc_args, **ufunc_kwargs)
718+
outputs = ufunc(*inputs, **node.tag.get("ufunc_kwargs", {}))
735719

736-
if nout == 1:
737-
variables = [variables]
720+
if not isinstance(outputs, tuple):
721+
outputs = (outputs,)
738722

739-
for i, (variable, storage, nout) in enumerate(
740-
zip(variables, output_storage, node.outputs)
723+
for i, (out, out_storage, node_out) in enumerate(
724+
zip(outputs, output_storage, node.outputs)
741725
):
742-
storage[0] = variable = np.asarray(variable, dtype=nout.dtype)
726+
# Numpy frompyfunc always returns object arrays
727+
out_storage[0] = out = np.asarray(out, dtype=node_out.dtype)
743728

744729
if i in self.inplace_pattern:
745-
odat = inputs[self.inplace_pattern[i]]
746-
odat[...] = variable
747-
storage[0] = odat
730+
inp = inputs[self.inplace_pattern[i]]
731+
inp[...] = out
732+
out_storage[0] = inp
748733

749734
# numpy.real return a view!
750-
if not variable.flags.owndata:
751-
storage[0] = variable.copy()
735+
if not out.flags.owndata:
736+
out_storage[0] = out.copy()
752737

753738
@staticmethod
754739
def _check_runtime_broadcast(node, inputs):

0 commit comments

Comments
 (0)