@@ -1,3 +1,4 @@
+import warnings
from collections.abc import Sequence
from copy import copy
from textwrap import dedent
@@ -19,9 +20,9 @@
from pytensor.misc.frozendict import frozendict
from pytensor.printing import Printer, pprint
from pytensor.scalar import get_scalar_type
+from pytensor.scalar.basic import Composite, transfer_type, upcast
from pytensor.scalar.basic import bool as scalar_bool
from pytensor.scalar.basic import identity as scalar_identity
-from pytensor.scalar.basic import transfer_type, upcast
from pytensor.tensor import elemwise_cgen as cgen
from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import _get_vector_length, as_tensor_variable
@@ -364,6 +365,7 @@ def __init__(
        self.name = name
        self.scalar_op = scalar_op
        self.inplace_pattern = inplace_pattern
+        self.ufunc = None
        self.destroy_map = {o: [i] for o, i in self.inplace_pattern.items()}

        if nfunc_spec is None:
@@ -375,14 +377,13 @@ def __init__(
    def __getstate__(self):
        d = copy(self.__dict__)
        d.pop("ufunc")
-        d.pop("nfunc")
        d.pop("__epydoc_asRoutine", None)
        return d

    def __setstate__(self, d):
+        d.pop("nfunc", None)  # This used to be stored in the Op, not anymore
        super().__setstate__(d)
        self.ufunc = None
-        self.nfunc = None
        self.inplace_pattern = frozendict(self.inplace_pattern)

    def get_output_info(self, *inputs):
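The `__setstate__` change keeps previously pickled graphs loadable: `nfunc` is no longer stored on the Op, so any stale entry is dropped from the incoming state instead of being restored. A tiny illustration of the pattern, using made-up state dicts rather than a real Elemwise:

    old_state = {"scalar_op": "add", "inplace_pattern": {}, "nfunc": None}  # pickled before this change
    new_state = {"scalar_op": "add", "inplace_pattern": {}}                 # pickled after this change

    for d in (old_state, new_state):
        d.pop("nfunc", None)  # removes the stale key if present, no-op otherwise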
@@ -623,31 +624,47 @@ def transform(r):

        return ret

-    def prepare_node(self, node, storage_map, compute_map, impl):
-        # Postpone the ufunc building to the last minutes due to:
-        # - NumPy ufunc support only up to 32 operands (inputs and outputs)
-        # But our c code support more.
-        # - nfunc is reused for scipy and scipy is optional
-        if (len(node.inputs) + len(node.outputs)) > 32 and impl == "py":
-            impl = "c"
-
-        if getattr(self, "nfunc_spec", None) and impl != "c":
-            self.nfunc = import_func_from_string(self.nfunc_spec[0])
-
+    def _create_node_ufunc(self, node) -> None:
        if (
-            (len(node.inputs) + len(node.outputs)) <= 32
-            and (self.nfunc is None or self.scalar_op.nin != len(node.inputs))
-            and self.ufunc is None
-            and impl == "py"
+            self.nfunc_spec is not None
+            # Some scalar Ops like `Add` allow for a variable number of inputs,
+            # whereas the numpy counterpart does not.
+            and len(node.inputs) == self.nfunc_spec[1]
        ):
+            ufunc = import_func_from_string(self.nfunc_spec[0])
+            if ufunc is None:
+                raise ValueError(
+                    f"Could not import ufunc {self.nfunc_spec[0]} for {self}"
+                )
+
+        elif self.ufunc is not None:
+            # Cached before
+            ufunc = self.ufunc
+
+        else:
+            if (len(node.inputs) + len(node.outputs)) > 32:
+                if isinstance(self.scalar_op, Composite):
+                    warnings.warn(
+                        "Trying to create a Python Composite Elemwise function with more than 32 operands.\n"
+                        "This operation should not have been introduced if the C-backend is not properly setup. "
+                        'Make sure it is, or disable it by setting pytensor.config.cxx = "" (empty string).\n'
+                        "Alternatively, consider using an optional backend like NUMBA or JAX, by setting "
+                        '`pytensor.config.mode = "NUMBA" (or "JAX").'
+                    )
+                else:
+                    warnings.warn(
+                        f"Trying to create a Python Elemwise function for the scalar Op {self.scalar_op} "
+                        f"with more than 32 operands. This will likely fail."
+                    )
+
            ufunc = np.frompyfunc(
                self.scalar_op.impl, len(node.inputs), self.scalar_op.nout
            )
-            if self.scalar_op.nin > 0:
-                # We can reuse it for many nodes
+            if self.scalar_op.nin > 0:  # Default in base class is -1
+                # Op has constant signature, so we can reuse ufunc for many nodes. Cache it.
                self.ufunc = ufunc
-            else:
-                node.tag.ufunc = ufunc
+
+        node.tag.ufunc = ufunc

        # Numpy ufuncs will sometimes perform operations in
        # float16, in particular when the input is int8.
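For context on the `np.frompyfunc` call kept above: it wraps a Python scalar function into a broadcasting ufunc, but the resulting arrays always have `object` dtype, which is why `perform` later casts each output with `np.asarray(..., dtype=...)`. A minimal standalone sketch, with an illustrative `scalar_add` standing in for `self.scalar_op.impl`:

    import numpy as np

    def scalar_add(x, y):  # stand-in for the scalar Op's Python impl
        return x + y

    ufunc = np.frompyfunc(scalar_add, 2, 1)  # nin=2, nout=1, mirroring the call above
    out = ufunc(np.arange(3), np.arange(3))
    print(out.dtype)                         # object
    print(np.asarray(out, dtype="int64"))    # [0 2 4], the cast perform() applies per output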
@@ -660,15 +677,23 @@ def prepare_node(self, node, storage_map, compute_map, impl):

        # NumPy 1.10.1 raise an error when giving the signature
        # when the input is complex. So add it only when inputs is int.
-        out_dtype = node.outputs[0].dtype
+        ufunc_kwargs = {}
        if (
-            out_dtype in float_dtypes
-            and isinstance(self.nfunc, np.ufunc)
+            isinstance(ufunc, np.ufunc)
+            # TODO: Why check for the dtype of the first input only?
            and node.inputs[0].dtype in discrete_dtypes
+            and len(node.outputs) == 1
+            and node.outputs[0].dtype in float_dtypes
        ):
-            char = np.sctype2char(out_dtype)
-            sig = char * node.nin + "->" + char * node.nout
-            node.tag.sig = sig
+            char = np.sctype2char(node.outputs[0].dtype)
+            ufunc_kwargs["sig"] = char * node.nin + "->" + char * node.nout
+
+        node.tag.ufunc_kwargs = ufunc_kwargs
+
+    def prepare_node(self, node, storage_map, compute_map, impl):
+        if impl == "py":
+            self._create_node_ufunc(node)
+
        node.tag.fake_node = Apply(
            self.scalar_op,
            [
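The `"sig"` entry exists because NumPy may otherwise select a low-precision loop for small integer inputs (for instance a float16 loop for int8), while PyTensor has already decided the output should be a wider float; forcing the signature makes the ufunc compute in that dtype. A rough illustration with a plain NumPy ufunc, using the long keyword `signature` (the code above passes the same keyword under its short name `sig`):

    import numpy as np

    x = np.arange(10, dtype="int8")
    print(np.exp(x).dtype)                    # float16: NumPy picks the half-precision loop
    print(np.exp(x, signature="f->f").dtype)  # float32: the forced signature wins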
@@ -684,71 +709,32 @@ def prepare_node(self, node, storage_map, compute_map, impl):
        self.scalar_op.prepare_node(node.tag.fake_node, None, None, impl)

    def perform(self, node, inputs, output_storage):
-        if (len(node.inputs) + len(node.outputs)) > 32:
-            # Some versions of NumPy will segfault, other will raise a
-            # ValueError, if the number of operands in an ufunc is more than 32.
-            # In that case, the C version should be used, or Elemwise fusion
-            # should be disabled.
-            # FIXME: This no longer calls the C implementation!
-            super().perform(node, inputs, output_storage)
+        ufunc = getattr(node.tag, "ufunc", None)
+        if ufunc is None:
+            self._create_node_ufunc(node)
+            ufunc = node.tag.ufunc

        self._check_runtime_broadcast(node, inputs)

-        ufunc_args = inputs
-        ufunc_kwargs = {}
-        # We supported in the past calling manually op.perform.
-        # To keep that support we need to sometimes call self.prepare_node
-        if self.nfunc is None and self.ufunc is None:
-            self.prepare_node(node, None, None, "py")
-        if self.nfunc and len(inputs) == self.nfunc_spec[1]:
-            ufunc = self.nfunc
-            nout = self.nfunc_spec[2]
-            if hasattr(node.tag, "sig"):
-                ufunc_kwargs["sig"] = node.tag.sig
-            # Unfortunately, the else case does not allow us to
-            # directly feed the destination arguments to the nfunc
-            # since it sometimes requires resizing. Doing this
-            # optimization is probably not worth the effort, since we
-            # should normally run the C version of the Op.
-        else:
-            # the second calling form is used because in certain versions of
-            # numpy the first (faster) version leads to segfaults
-            if self.ufunc:
-                ufunc = self.ufunc
-            elif not hasattr(node.tag, "ufunc"):
-                # It happen that make_thunk isn't called, like in
-                # get_underlying_scalar_constant_value
-                self.prepare_node(node, None, None, "py")
-                # prepare_node will add ufunc to self or the tag
-                # depending if we can reuse it or not. So we need to
-                # test both again.
-                if self.ufunc:
-                    ufunc = self.ufunc
-                else:
-                    ufunc = node.tag.ufunc
-            else:
-                ufunc = node.tag.ufunc
-
-            nout = ufunc.nout
-
-        variables = ufunc(*ufunc_args, **ufunc_kwargs)
+        outputs = ufunc(*inputs, **node.tag.ufunc_kwargs)

-        if nout == 1:
-            variables = [variables]
+        if not isinstance(outputs, tuple):
+            outputs = (outputs,)

-        for i, (variable, storage, nout) in enumerate(
-            zip(variables, output_storage, node.outputs)
+        for i, (out, out_storage, node_out) in enumerate(
+            zip(outputs, output_storage, node.outputs)
        ):
-            storage[0] = variable = np.asarray(variable, dtype=nout.dtype)
+            # Numpy frompyfunc always returns object arrays
+            out_storage[0] = out = np.asarray(out, dtype=node_out.dtype)

            if i in self.inplace_pattern:
-                odat = inputs[self.inplace_pattern[i]]
-                odat[...] = variable
-                storage[0] = odat
+                inp = inputs[self.inplace_pattern[i]]
+                inp[...] = out
+                out_storage[0] = inp

            # numpy.real return a view!
-            if not variable.flags.owndata:
-                storage[0] = variable.copy()
+            if not out.flags.owndata:
+                out_storage[0] = out.copy()

    @staticmethod
    def _check_runtime_broadcast(node, inputs):
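Taken together, the rewritten `perform` only relies on what `_create_node_ufunc` cached on `node.tag`. A self-contained sketch of that flow, where `SimpleNamespace` stands in for the real Apply node and its tag (not PyTensor API):

    import numpy as np
    from types import SimpleNamespace

    node = SimpleNamespace(tag=SimpleNamespace())

    # What prepare_node(impl="py"), or the first perform() call, would cache:
    node.tag.ufunc = np.frompyfunc(lambda x, y: x + y, 2, 1)
    node.tag.ufunc_kwargs = {}

    inputs = [np.arange(3), np.ones(3, dtype="int64")]
    outputs = node.tag.ufunc(*inputs, **node.tag.ufunc_kwargs)
    if not isinstance(outputs, tuple):
        outputs = (outputs,)
    print(np.asarray(outputs[0], dtype="int64"))  # [1 2 3], object array cast to the node dtype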