@@ -1,3 +1,4 @@
+import warnings
 from collections.abc import Sequence
 from copy import copy
 from textwrap import dedent
@@ -19,9 +20,9 @@
 from pytensor.misc.frozendict import frozendict
 from pytensor.printing import Printer, pprint
 from pytensor.scalar import get_scalar_type
+from pytensor.scalar.basic import Composite, transfer_type, upcast
 from pytensor.scalar.basic import bool as scalar_bool
 from pytensor.scalar.basic import identity as scalar_identity
-from pytensor.scalar.basic import transfer_type, upcast
 from pytensor.tensor import elemwise_cgen as cgen
 from pytensor.tensor import get_vector_length
 from pytensor.tensor.basic import _get_vector_length, as_tensor_variable
@@ -364,6 +365,7 @@ def __init__(
         self.name = name
         self.scalar_op = scalar_op
         self.inplace_pattern = inplace_pattern
+        self.ufunc = None
         self.destroy_map = {o: [i] for o, i in self.inplace_pattern.items()}

         if nfunc_spec is None:
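Note on the hunk above: the destroy_map comprehension simply maps each output index to a list holding the input index it overwrites. A tiny standalone check, illustrative only (the pattern value is hypothetical, not from the diff):

inplace_pattern = {0: 1}  # hypothetical: output 0 reuses (destroys) input 1
destroy_map = {o: [i] for o, i in inplace_pattern.items()}
assert destroy_map == {0: [1]}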
@@ -375,14 +377,12 @@ def __init__(
     def __getstate__(self):
         d = copy(self.__dict__)
         d.pop("ufunc")
-        d.pop("nfunc")
-        d.pop("__epydoc_asRoutine", None)
         return d

     def __setstate__(self, d):
+        d.pop("nfunc", None)  # This used to be stored in the Op, not anymore
         super().__setstate__(d)
         self.ufunc = None
-        self.nfunc = None
         self.inplace_pattern = frozendict(self.inplace_pattern)

     def get_output_info(self, *inputs):
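For context, a minimal sketch of the intended (de)serialization behaviour after this hunk: the lazily built ufunc is dropped by __getstate__ and reset to None by __setstate__, while stale "nfunc" entries from older pickles are discarded. The import paths below are assumptions made for the illustration:

import pickle

from pytensor.scalar import add as scalar_add
from pytensor.tensor.elemwise import Elemwise

op = Elemwise(scalar_add)
restored = pickle.loads(pickle.dumps(op))
assert restored.ufunc is None  # rebuilt lazily by perform/prepare_node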
@@ -623,31 +623,47 @@ def transform(r):

         return ret

-    def prepare_node(self, node, storage_map, compute_map, impl):
-        # Postpone the ufunc building to the last minutes due to:
-        # - NumPy ufunc support only up to 32 operands (inputs and outputs)
-        #   But our c code support more.
-        # - nfunc is reused for scipy and scipy is optional
-        if (len(node.inputs) + len(node.outputs)) > 32 and impl == "py":
-            impl = "c"
-
-        if getattr(self, "nfunc_spec", None) and impl != "c":
-            self.nfunc = import_func_from_string(self.nfunc_spec[0])
-
+    def _create_node_ufunc(self, node) -> None:
         if (
-            (len(node.inputs) + len(node.outputs)) <= 32
-            and (self.nfunc is None or self.scalar_op.nin != len(node.inputs))
-            and self.ufunc is None
-            and impl == "py"
+            self.nfunc_spec is not None
+            # Some scalar Ops like `Add` allow for a variable number of inputs,
+            # whereas the numpy counterpart does not.
+            and len(node.inputs) == self.nfunc_spec[1]
         ):
+            ufunc = import_func_from_string(self.nfunc_spec[0])
+            if ufunc is None:
+                raise ValueError(
+                    f"Could not import ufunc {self.nfunc_spec[0]} for {self}"
+                )
+
+        elif self.ufunc is not None:
+            # Cached before
+            ufunc = self.ufunc
+
+        else:
+            if (len(node.inputs) + len(node.outputs)) > 32:
+                if isinstance(self.scalar_op, Composite):
+                    warnings.warn(
+                        "Trying to create a Python Composite Elemwise function with more than 32 operands.\n"
+                        "This operation should not have been introduced if the C-backend is not properly setup. "
+                        'Make sure it is, or disable it by setting pytensor.config.cxx = "" (empty string).\n'
+                        "Alternatively, consider using an optional backend like NUMBA or JAX, by setting "
+                        '`pytensor.config.mode = "NUMBA" (or "JAX").'
+                    )
+                else:
+                    warnings.warn(
+                        f"Trying to create a Python Elemwise function for the scalar Op {self.scalar_op} "
+                        f"with more than 32 operands. This will likely fail."
+                    )
+
             ufunc = np.frompyfunc(
                 self.scalar_op.impl, len(node.inputs), self.scalar_op.nout
             )
-            if self.scalar_op.nin > 0:
-                # We can reuse it for many nodes
+            if self.scalar_op.nin > 0:  # Default in base class is -1
+                # Op has constant signature, so we can reuse ufunc for many nodes. Cache it.
                 self.ufunc = ufunc
-            else:
-                node.tag.ufunc = ufunc
+
+        node.tag.ufunc = ufunc

         # Numpy ufuncs will sometimes perform operations in
         # float16, in particular when the input is int8.
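The fallback branch above relies on np.frompyfunc, which turns the scalar op's Python impl into a broadcasting ufunc that returns object arrays (hence the later cast in perform). A standalone NumPy illustration, with scalar_impl standing in for self.scalar_op.impl:

import numpy as np

def scalar_impl(x, y):  # stand-in for self.scalar_op.impl
    return x + y

ufunc = np.frompyfunc(scalar_impl, 2, 1)  # nin=2, nout=1
res = ufunc(np.arange(3), np.arange(3))
print(res.dtype)                        # object
print(np.asarray(res, dtype="int64"))   # [0 2 4]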
@@ -660,15 +676,23 @@ def prepare_node(self, node, storage_map, compute_map, impl):

         # NumPy 1.10.1 raise an error when giving the signature
         # when the input is complex. So add it only when inputs is int.
-        out_dtype = node.outputs[0].dtype
+        ufunc_kwargs = {}
         if (
-            out_dtype in float_dtypes
-            and isinstance(self.nfunc, np.ufunc)
+            isinstance(ufunc, np.ufunc)
+            # TODO: Why check for the dtype of the first input only?
             and node.inputs[0].dtype in discrete_dtypes
+            and len(node.outputs) == 1
+            and node.outputs[0].dtype in float_dtypes
         ):
-            char = np.sctype2char(out_dtype)
-            sig = char * node.nin + "->" + char * node.nout
-            node.tag.sig = sig
+            char = np.sctype2char(node.outputs[0].dtype)
+            ufunc_kwargs["sig"] = char * node.nin + "->" + char * node.nout
+
+        node.tag.ufunc_kwargs = ufunc_kwargs
+
+    def prepare_node(self, node, storage_map, compute_map, impl):
+        if impl == "py":
+            self._create_node_ufunc(node)
+
         node.tag.fake_node = Apply(
             self.scalar_op,
             [
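The "sig" entry built above is a ufunc dtype signature (one type character per operand) that forces integer inputs to be computed in the float output dtype. A rough standalone illustration with np.add, using the documented signature keyword (the code above passes the same string under the shorter "sig" alias); the float32 choice here is only an example:

import numpy as np

x = np.array([1, 2, 3], dtype="int8")
char = np.dtype("float32").char            # 'f'
sig = char * 2 + "->" + char               # 'ff->f'
print(np.add(x, x, signature=sig).dtype)   # float32 rather than int8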
@@ -684,71 +708,32 @@ def prepare_node(self, node, storage_map, compute_map, impl):
         self.scalar_op.prepare_node(node.tag.fake_node, None, None, impl)

     def perform(self, node, inputs, output_storage):
-        if (len(node.inputs) + len(node.outputs)) > 32:
-            # Some versions of NumPy will segfault, other will raise a
-            # ValueError, if the number of operands in an ufunc is more than 32.
-            # In that case, the C version should be used, or Elemwise fusion
-            # should be disabled.
-            # FIXME: This no longer calls the C implementation!
-            super().perform(node, inputs, output_storage)
+        ufunc = getattr(node.tag, "ufunc", None)
+        if ufunc is None:
+            self._create_node_ufunc(node)
+            ufunc = node.tag.ufunc

         self._check_runtime_broadcast(node, inputs)

-        ufunc_args = inputs
-        ufunc_kwargs = {}
-        # We supported in the past calling manually op.perform.
-        # To keep that support we need to sometimes call self.prepare_node
-        if self.nfunc is None and self.ufunc is None:
-            self.prepare_node(node, None, None, "py")
-        if self.nfunc and len(inputs) == self.nfunc_spec[1]:
-            ufunc = self.nfunc
-            nout = self.nfunc_spec[2]
-            if hasattr(node.tag, "sig"):
-                ufunc_kwargs["sig"] = node.tag.sig
-            # Unfortunately, the else case does not allow us to
-            # directly feed the destination arguments to the nfunc
-            # since it sometimes requires resizing. Doing this
-            # optimization is probably not worth the effort, since we
-            # should normally run the C version of the Op.
-        else:
-            # the second calling form is used because in certain versions of
-            # numpy the first (faster) version leads to segfaults
-            if self.ufunc:
-                ufunc = self.ufunc
-            elif not hasattr(node.tag, "ufunc"):
-                # It happen that make_thunk isn't called, like in
-                # get_underlying_scalar_constant_value
-                self.prepare_node(node, None, None, "py")
-                # prepare_node will add ufunc to self or the tag
-                # depending if we can reuse it or not. So we need to
-                # test both again.
-                if self.ufunc:
-                    ufunc = self.ufunc
-                else:
-                    ufunc = node.tag.ufunc
-            else:
-                ufunc = node.tag.ufunc
-
-        nout = ufunc.nout
-
-        variables = ufunc(*ufunc_args, **ufunc_kwargs)
+        outputs = ufunc(*inputs, **node.tag.get("ufunc_kwargs", {}))

-        if nout == 1:
-            variables = [variables]
+        if not isinstance(outputs, tuple):
+            outputs = (outputs,)

-        for i, (variable, storage, nout) in enumerate(
-            zip(variables, output_storage, node.outputs)
+        for i, (out, out_storage, node_out) in enumerate(
+            zip(outputs, output_storage, node.outputs)
         ):
-            storage[0] = variable = np.asarray(variable, dtype=nout.dtype)
+            # Numpy frompyfunc always returns object arrays
+            out_storage[0] = out = np.asarray(out, dtype=node_out.dtype)

             if i in self.inplace_pattern:
-                odat = inputs[self.inplace_pattern[i]]
-                odat[...] = variable
-                storage[0] = odat
+                inp = inputs[self.inplace_pattern[i]]
+                inp[...] = out
+                out_storage[0] = inp

             # numpy.real return a view!
-            if not variable.flags.owndata:
-                storage[0] = variable.copy()
+            if not out.flags.owndata:
+                out_storage[0] = out.copy()

     @staticmethod
     def _check_runtime_broadcast(node, inputs):
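One behavioural detail the new perform relies on: NumPy ufuncs return a bare array when they have a single output and a tuple when they have several, which is what the isinstance(outputs, tuple) normalization handles. A standalone check in plain NumPy:

import numpy as np

single = np.add(np.arange(3), 1)     # ndarray, not a tuple
multi = np.divmod(np.arange(3), 2)   # tuple of two ndarrays
assert not isinstance(single, tuple)
assert isinstance(multi, tuple) and len(multi) == 2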