1
1
from collections .abc import Sequence
2
- from copy import copy
3
2
from typing import Any , cast
4
3
5
4
import numpy as np
@@ -79,7 +78,6 @@ def __init__(
79
78
self .name = name
80
79
self .inputs_sig , self .outputs_sig = _parse_gufunc_signature (signature )
81
80
self .gufunc_spec = gufunc_spec
82
- self ._gufunc = None
83
81
if destroy_map is not None :
84
82
self .destroy_map = destroy_map
85
83
if self .destroy_map != core_op .destroy_map :
@@ -91,11 +89,6 @@ def __init__(
91
89
92
90
super ().__init__ (** kwargs )
93
91
94
- def __getstate__ (self ):
95
- d = copy (self .__dict__ )
96
- d ["_gufunc" ] = None
97
- return d
98
-
99
92
def _create_dummy_core_node (self , inputs : Sequence [TensorVariable ]) -> Apply :
100
93
core_input_types = []
101
94
for i , (inp , sig ) in enumerate (zip (inputs , self .inputs_sig )):
@@ -296,32 +289,40 @@ def L_op(self, inputs, outs, ograds):
296
289
297
290
return rval
298
291
299
def _create_node_gufunc(self, node) -> None:
    """Define the gufunc used by `perform` and cache it on the node.

    If the Blockwise or its core_op provide a `gufunc_spec`, the named
    numpy/scipy generalized ufunc is imported and used directly.
    Otherwise we default to `np.vectorize` of the core_op's `perform`
    method applied to a dummy core node.

    The resulting callable is stored in ``node.tag.gufunc`` rather than
    on ``self``: the cache is per-node (one Op instance may serve many
    nodes) and the Op itself stays trivially picklable, so no
    ``__getstate__`` workaround is needed.

    Raises
    ------
    ValueError
        If `gufunc_spec` names a function that cannot be imported.
    """
    gufunc_spec = self.gufunc_spec or getattr(self.core_op, "gufunc_spec", None)

    if gufunc_spec is not None:
        gufunc = import_func_from_string(gufunc_spec[0])
        if gufunc is None:
            raise ValueError(f"Could not import gufunc {gufunc_spec[0]} for {self}")
    else:
        # Wrap core_op perform method in numpy vectorize
        n_outs = len(self.outputs_sig)
        core_node = self._create_dummy_core_node(node.inputs)

        def core_func(*inner_inputs):
            # perform fills output storage in place: one [None] cell per output
            inner_outputs = [[None] for _ in range(n_outs)]

            inner_inputs = [np.asarray(inp) for inp in inner_inputs]
            self.core_op.perform(core_node, inner_inputs, inner_outputs)

            if len(inner_outputs) == 1:
                return inner_outputs[0][0]
            else:
                return tuple(r[0] for r in inner_outputs)

        gufunc = np.vectorize(core_func, signature=self.signature)

    node.tag.gufunc = gufunc
325
326
326
327
def _check_runtime_broadcast (self , node , inputs ):
327
328
batch_ndim = self .batch_ndim (node )
@@ -340,10 +341,12 @@ def _check_runtime_broadcast(self, node, inputs):
340
341
)
341
342
342
343
def perform (self , node , inputs , output_storage ):
343
- gufunc = self . _gufunc
344
+ gufunc = getattr ( node . tag , "gufunc" , None )
344
345
345
346
if gufunc is None :
346
- gufunc = self ._create_gufunc (node )
347
+ # Cache it once per node
348
+ self ._create_node_gufunc (node )
349
+ gufunc = node .tag .gufunc
347
350
348
351
self ._check_runtime_broadcast (node , inputs )
349
352
0 commit comments