Skip to content

Commit 0ee777a

Browse files
committed
Simplify Elemwise perform method and issue informative warning when number of operands is too large.
This also clears a hard to debug error when perform method attempted to falback to the C-implementation.
1 parent 5909d93 commit 0ee777a

File tree

1 file changed

+69
-83
lines changed

1 file changed

+69
-83
lines changed

pytensor/tensor/elemwise.py

+69-83
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from collections.abc import Sequence
23
from copy import copy
34
from textwrap import dedent
@@ -19,9 +20,9 @@
1920
from pytensor.misc.frozendict import frozendict
2021
from pytensor.printing import Printer, pprint
2122
from pytensor.scalar import get_scalar_type
23+
from pytensor.scalar.basic import Composite, transfer_type, upcast
2224
from pytensor.scalar.basic import bool as scalar_bool
2325
from pytensor.scalar.basic import identity as scalar_identity
24-
from pytensor.scalar.basic import transfer_type, upcast
2526
from pytensor.tensor import elemwise_cgen as cgen
2627
from pytensor.tensor import get_vector_length
2728
from pytensor.tensor.basic import _get_vector_length, as_tensor_variable
@@ -364,6 +365,7 @@ def __init__(
364365
self.name = name
365366
self.scalar_op = scalar_op
366367
self.inplace_pattern = inplace_pattern
368+
self.ufunc = None
367369
self.destroy_map = {o: [i] for o, i in self.inplace_pattern.items()}
368370

369371
if nfunc_spec is None:
@@ -375,14 +377,13 @@ def __init__(
375377
def __getstate__(self):
376378
d = copy(self.__dict__)
377379
d.pop("ufunc")
378-
d.pop("nfunc")
379380
d.pop("__epydoc_asRoutine", None)
380381
return d
381382

382383
def __setstate__(self, d):
384+
d.pop("nfunc", None) # This used to be stored in the Op, not anymore
383385
super().__setstate__(d)
384386
self.ufunc = None
385-
self.nfunc = None
386387
self.inplace_pattern = frozendict(self.inplace_pattern)
387388

388389
def get_output_info(self, *inputs):
@@ -623,31 +624,47 @@ def transform(r):
623624

624625
return ret
625626

626-
def prepare_node(self, node, storage_map, compute_map, impl):
627-
# Postpone the ufunc building to the last minutes due to:
628-
# - NumPy ufunc support only up to 32 operands (inputs and outputs)
629-
# But our c code support more.
630-
# - nfunc is reused for scipy and scipy is optional
631-
if (len(node.inputs) + len(node.outputs)) > 32 and impl == "py":
632-
impl = "c"
633-
634-
if getattr(self, "nfunc_spec", None) and impl != "c":
635-
self.nfunc = import_func_from_string(self.nfunc_spec[0])
636-
627+
def _create_node_ufunc(self, node) -> None:
637628
if (
638-
(len(node.inputs) + len(node.outputs)) <= 32
639-
and (self.nfunc is None or self.scalar_op.nin != len(node.inputs))
640-
and self.ufunc is None
641-
and impl == "py"
629+
self.nfunc_spec is not None
630+
# Some scalar Ops like `Add` allow for a variable number of inputs,
631+
# whereas the numpy counterpart does not.
632+
and len(node.inputs) == self.nfunc_spec[1]
642633
):
634+
ufunc = import_func_from_string(self.nfunc_spec[0])
635+
if ufunc is None:
636+
raise ValueError(
637+
f"Could not import ufunc {self.nfunc_spec[0]} for {self}"
638+
)
639+
640+
elif self.ufunc is not None:
641+
# Cached before
642+
ufunc = self.ufunc
643+
644+
else:
645+
if (len(node.inputs) + len(node.outputs)) > 32:
646+
if isinstance(self.scalar_op, Composite):
647+
warnings.warn(
648+
"Trying to create a Python Composite Elemwise function with more than 32 operands.\n"
649+
"This operation should not have been introduced if the C-backend is not properly setup. "
650+
'Make sure it is, or disable it by setting pytensor.config.cxx = "" (empty string).\n'
651+
"Alternatively, consider using an optional backend like NUMBA or JAX, by setting "
652+
'`pytensor.config.mode = "NUMBA" (or "JAX").'
653+
)
654+
else:
655+
warnings.warn(
656+
f"Trying to create a Python Elemwise function for the scalar Op {self.scalar_op} "
657+
f"with more than 32 operands. This will likely fail."
658+
)
659+
643660
ufunc = np.frompyfunc(
644661
self.scalar_op.impl, len(node.inputs), self.scalar_op.nout
645662
)
646-
if self.scalar_op.nin > 0:
647-
# We can reuse it for many nodes
663+
if self.scalar_op.nin > 0: # Default in base class is -1
664+
# Op has constant signature, so we can reuse ufunc for many nodes. Cache it.
648665
self.ufunc = ufunc
649-
else:
650-
node.tag.ufunc = ufunc
666+
667+
node.tag.ufunc = ufunc
651668

652669
# Numpy ufuncs will sometimes perform operations in
653670
# float16, in particular when the input is int8.
@@ -660,15 +677,23 @@ def prepare_node(self, node, storage_map, compute_map, impl):
660677

661678
# NumPy 1.10.1 raise an error when giving the signature
662679
# when the input is complex. So add it only when inputs is int.
663-
out_dtype = node.outputs[0].dtype
680+
ufunc_kwargs = {}
664681
if (
665-
out_dtype in float_dtypes
666-
and isinstance(self.nfunc, np.ufunc)
682+
isinstance(ufunc, np.ufunc)
683+
# TODO: Why check for the dtype of the first input only?
667684
and node.inputs[0].dtype in discrete_dtypes
685+
and len(node.outputs) == 1
686+
and node.outputs[0].dtype in float_dtypes
668687
):
669-
char = np.sctype2char(out_dtype)
670-
sig = char * node.nin + "->" + char * node.nout
671-
node.tag.sig = sig
688+
char = np.sctype2char(node.outputs[0].dtype)
689+
ufunc_kwargs["sig"] = char * node.nin + "->" + char * node.nout
690+
691+
node.tag.ufunc_kwargs = ufunc_kwargs
692+
693+
def prepare_node(self, node, storage_map, compute_map, impl):
694+
if impl == "py":
695+
self._create_node_ufunc(node)
696+
672697
node.tag.fake_node = Apply(
673698
self.scalar_op,
674699
[
@@ -684,71 +709,32 @@ def prepare_node(self, node, storage_map, compute_map, impl):
684709
self.scalar_op.prepare_node(node.tag.fake_node, None, None, impl)
685710

686711
def perform(self, node, inputs, output_storage):
687-
if (len(node.inputs) + len(node.outputs)) > 32:
688-
# Some versions of NumPy will segfault, other will raise a
689-
# ValueError, if the number of operands in an ufunc is more than 32.
690-
# In that case, the C version should be used, or Elemwise fusion
691-
# should be disabled.
692-
# FIXME: This no longer calls the C implementation!
693-
super().perform(node, inputs, output_storage)
712+
ufunc = getattr(node.tag, "ufunc", None)
713+
if ufunc is None:
714+
self._create_node_ufunc(node)
715+
ufunc = node.tag.ufunc
694716

695717
self._check_runtime_broadcast(node, inputs)
696718

697-
ufunc_args = inputs
698-
ufunc_kwargs = {}
699-
# We supported in the past calling manually op.perform.
700-
# To keep that support we need to sometimes call self.prepare_node
701-
if self.nfunc is None and self.ufunc is None:
702-
self.prepare_node(node, None, None, "py")
703-
if self.nfunc and len(inputs) == self.nfunc_spec[1]:
704-
ufunc = self.nfunc
705-
nout = self.nfunc_spec[2]
706-
if hasattr(node.tag, "sig"):
707-
ufunc_kwargs["sig"] = node.tag.sig
708-
# Unfortunately, the else case does not allow us to
709-
# directly feed the destination arguments to the nfunc
710-
# since it sometimes requires resizing. Doing this
711-
# optimization is probably not worth the effort, since we
712-
# should normally run the C version of the Op.
713-
else:
714-
# the second calling form is used because in certain versions of
715-
# numpy the first (faster) version leads to segfaults
716-
if self.ufunc:
717-
ufunc = self.ufunc
718-
elif not hasattr(node.tag, "ufunc"):
719-
# It happen that make_thunk isn't called, like in
720-
# get_underlying_scalar_constant_value
721-
self.prepare_node(node, None, None, "py")
722-
# prepare_node will add ufunc to self or the tag
723-
# depending if we can reuse it or not. So we need to
724-
# test both again.
725-
if self.ufunc:
726-
ufunc = self.ufunc
727-
else:
728-
ufunc = node.tag.ufunc
729-
else:
730-
ufunc = node.tag.ufunc
731-
732-
nout = ufunc.nout
733-
734-
variables = ufunc(*ufunc_args, **ufunc_kwargs)
719+
outputs = ufunc(*inputs, **node.tag.ufunc_kwargs)
735720

736-
if nout == 1:
737-
variables = [variables]
721+
if not isinstance(outputs, tuple):
722+
outputs = (outputs,)
738723

739-
for i, (variable, storage, nout) in enumerate(
740-
zip(variables, output_storage, node.outputs)
724+
for i, (out, out_storage, node_out) in enumerate(
725+
zip(outputs, output_storage, node.outputs)
741726
):
742-
storage[0] = variable = np.asarray(variable, dtype=nout.dtype)
727+
# Numpy frompyfunc always returns object arrays
728+
out_storage[0] = out = np.asarray(out, dtype=node_out.dtype)
743729

744730
if i in self.inplace_pattern:
745-
odat = inputs[self.inplace_pattern[i]]
746-
odat[...] = variable
747-
storage[0] = odat
731+
inp = inputs[self.inplace_pattern[i]]
732+
inp[...] = out
733+
out_storage[0] = inp
748734

749735
# numpy.real return a view!
750-
if not variable.flags.owndata:
751-
storage[0] = variable.copy()
736+
if not out.flags.owndata:
737+
out_storage[0] = out.copy()
752738

753739
@staticmethod
754740
def _check_runtime_broadcast(node, inputs):

0 commit comments

Comments
 (0)