@@ -326,8 +326,8 @@ class Function:
326
326
def __init__ (
327
327
self ,
328
328
vm : "VM" ,
329
- input_storage ,
330
- output_storage ,
329
+ input_storage : list [ Container ] ,
330
+ output_storage : list [ Container ] ,
331
331
indices ,
332
332
outputs ,
333
333
defaults ,
@@ -372,7 +372,6 @@ def __init__(
372
372
name
373
373
A string name.
374
374
"""
375
- # TODO: Rename to `vm`
376
375
self .vm = vm
377
376
self .input_storage = input_storage
378
377
self .output_storage = output_storage
@@ -388,31 +387,49 @@ def __init__(
388
387
self .nodes_with_inner_function = []
389
388
self .output_keys = output_keys
390
389
391
- # See if we have any mutable / borrow inputs
392
- # TODO: this only need to be set if there is more than one input
393
- self ._check_for_aliased_inputs = False
394
- for i in maker .inputs :
395
- # If the input is a shared variable, the memory region is
396
- # under PyTensor control and so we don't need to check if it
397
- # is aliased as we never do that.
398
- if (
399
- isinstance (i , In )
400
- and not i .shared
401
- and (getattr (i , "borrow" , False ) or getattr (i , "mutable" , False ))
390
+ assert len (self .input_storage ) == len (self .maker .fgraph .inputs )
391
+ assert len (self .output_storage ) == len (self .maker .fgraph .outputs )
392
+
393
+ # Group indexes of inputs that are potentially aliased to each other
394
+ # Note: Historically, we only worried about aliasing inputs if they belonged to the same type,
395
+ # even though there could be two distinct types that use the same kinds of underlying objects.
396
+ potential_aliased_input_groups = []
397
+ for inp in maker .inputs :
398
+ # If the input is a shared variable, the memory region is under PyTensor control
399
+ # and can't be aliased.
400
+ if not (
401
+ isinstance (inp , In )
402
+ and inp .borrow
403
+ and not inp .shared
404
+ and hasattr (inp .variable .type , "may_share_memory" )
402
405
):
403
- self ._check_for_aliased_inputs = True
404
- break
406
+ continue
407
+
408
+ for group in potential_aliased_input_groups :
409
+ # If one is super of the other, that means one could be replaced by the other
410
+ if any (
411
+ inp .variable .type .is_super (other_inp .variable .type )
412
+ or other_inp .variable .type .is_super (inp .variable .type )
413
+ for other_inp in group
414
+ ):
415
+ group .append (inp )
416
+ break
417
+ else : # no break
418
+ # Input makes a new group
419
+ potential_aliased_input_groups .append ([inp ])
420
+
421
+ # Potential aliased inputs are those that belong to the same group
422
+ self ._potential_aliased_input_groups : tuple [tuple [int , ...], ...] = tuple (
423
+ tuple (maker .inputs .index (inp ) for inp in group )
424
+ for group in potential_aliased_input_groups
425
+ if len (group ) > 1
426
+ )
405
427
406
428
# We will be popping stuff off this `containers` object. It is a copy.
407
429
containers = list (self .input_storage )
408
430
finder = {}
409
431
inv_finder = {}
410
432
411
- def distribute (indices , cs , value ):
412
- input .distribute (value , indices , cs )
413
- for c in cs :
414
- c .provided += 1
415
-
416
433
# Store the list of names of named inputs.
417
434
named_inputs = []
418
435
# Count the number of un-named inputs.
@@ -777,6 +794,13 @@ def checkSV(sv_ori, sv_rpl):
777
794
f_cpy .maker .fgraph .name = name
778
795
return f_cpy
779
796
797
+ def _restore_defaults (self ):
798
+ for i , (required , refeed , value ) in enumerate (self .defaults ):
799
+ if refeed :
800
+ if isinstance (value , Container ):
801
+ value = value .storage [0 ]
802
+ self [i ] = value
803
+
780
804
def __call__ (self , * args , ** kwargs ):
781
805
"""
782
806
Evaluates value of a function on given arguments.
@@ -805,52 +829,43 @@ def __call__(self, *args, **kwargs):
805
829
List of outputs on indices/keys from ``output_subset`` or all of them,
806
830
if ``output_subset`` is not passed.
807
831
"""
808
-
809
- def restore_defaults ():
810
- for i , (required , refeed , value ) in enumerate (self .defaults ):
811
- if refeed :
812
- if isinstance (value , Container ):
813
- value = value .storage [0 ]
814
- self [i ] = value
815
-
832
+ input_storage = self .input_storage
816
833
profile = self .profile
817
- t0 = time .perf_counter ()
834
+
835
+ if profile :
836
+ t0 = time .perf_counter ()
818
837
819
838
output_subset = kwargs .pop ("output_subset" , None )
820
839
if output_subset is not None and self .output_keys is not None :
821
840
output_subset = [self .output_keys .index (key ) for key in output_subset ]
822
841
823
842
# Reinitialize each container's 'provided' counter
824
843
if self .trust_input :
825
- i = 0
826
- for arg in args :
827
- s = self .input_storage [i ]
828
- s .storage [0 ] = arg
829
- i += 1
844
+ for arg_container , arg in zip (input_storage , args , strict = False ):
845
+ arg_container .storage [0 ] = arg
830
846
else :
831
- for c in self . input_storage :
832
- c .provided = 0
847
+ for arg_container in input_storage :
848
+ arg_container .provided = 0
833
849
834
- if len (args ) + len (kwargs ) > len (self . input_storage ):
850
+ if len (args ) + len (kwargs ) > len (input_storage ):
835
851
raise TypeError ("Too many parameter passed to pytensor function" )
836
852
837
853
# Set positional arguments
838
- i = 0
839
- for arg in args :
840
- # TODO: provide a option for skipping the filter if we really
841
- # want speed.
842
- s = self .input_storage [i ]
843
- # see this emails for a discuation about None as input
854
+ for arg_container , arg in zip (input_storage , args , strict = False ):
855
+ # See discussion about None as input
844
856
# https://groups.google.com/group/theano-dev/browse_thread/thread/920a5e904e8a8525/4f1b311a28fc27e5
845
857
if arg is None :
846
- s .storage [0 ] = arg
858
+ arg_container .storage [0 ] = arg
847
859
else :
848
860
try :
849
- s .storage [0 ] = s .type .filter (
850
- arg , strict = s .strict , allow_downcast = s .allow_downcast
861
+ arg_container .storage [0 ] = arg_container .type .filter (
862
+ arg ,
863
+ strict = arg_container .strict ,
864
+ allow_downcast = arg_container .allow_downcast ,
851
865
)
852
866
853
867
except Exception as e :
868
+ i = input_storage .index (arg_container )
854
869
function_name = "pytensor function"
855
870
argument_name = "argument"
856
871
if self .name :
@@ -875,93 +890,74 @@ def restore_defaults():
875
890
+ function_name
876
891
+ f" at index { int (i )} (0-based). { where } "
877
892
) + e .args
878
- restore_defaults ()
893
+ self . _restore_defaults ()
879
894
raise
880
- s .provided += 1
881
- i += 1
895
+ arg_container .provided += 1
882
896
883
897
# Set keyword arguments
884
898
if kwargs : # for speed, skip the items for empty kwargs
885
899
for k , arg in kwargs .items ():
886
900
self [k ] = arg
887
901
888
- if (
889
- not self .trust_input
890
- and
891
- # The getattr is only needed for old pickle
892
- getattr (self , "_check_for_aliased_inputs" , True )
893
- ):
902
+ if not self .trust_input :
894
903
# Collect aliased inputs among the storage space
895
- args_share_memory = []
896
- for i in range (len (self .input_storage )):
897
- i_var = self .maker .inputs [i ].variable
898
- i_val = self .input_storage [i ].storage [0 ]
899
- if hasattr (i_var .type , "may_share_memory" ):
900
- is_aliased = False
901
- for j in range (len (args_share_memory )):
902
- group_j = zip (
903
- [
904
- self .maker .inputs [k ].variable
905
- for k in args_share_memory [j ]
906
- ],
907
- [
908
- self .input_storage [k ].storage [0 ]
909
- for k in args_share_memory [j ]
910
- ],
911
- )
904
+ for potential_group in self ._potential_aliased_input_groups :
905
+ args_share_memory : list [list [int ]] = []
906
+ for i in potential_group :
907
+ i_type = self .maker .inputs [i ].variable .type
908
+ i_val = input_storage [i ].storage [0 ]
909
+
910
+ # Check if value is aliased with any of the values in one of the groups
911
+ for j_group in args_share_memory :
912
912
if any (
913
- (
914
- var .type is i_var .type
915
- and var .type .may_share_memory (val , i_val )
916
- )
917
- for (var , val ) in group_j
913
+ i_type .may_share_memory (input_storage [j ].storage [0 ], i_val )
914
+ for j in j_group
918
915
):
919
- is_aliased = True
920
- args_share_memory [j ].append (i )
916
+ j_group .append (i )
921
917
break
922
-
923
- if not is_aliased :
918
+ else : # no break
919
+ # Create a new group
924
920
args_share_memory .append ([i ])
925
921
926
- # Check for groups of more than one argument that share memory
927
- for group in args_share_memory :
928
- if len (group ) > 1 :
929
- # copy all but the first
930
- for j in group [1 :]:
931
- self . input_storage [j ].storage [0 ] = copy .copy (
932
- self . input_storage [j ].storage [0 ]
933
- )
922
+ # Check for groups of more than one argument that share memory
923
+ for group in args_share_memory :
924
+ if len (group ) > 1 :
925
+ # copy all but the first
926
+ for i in group [1 :]:
927
+ input_storage [i ].storage [0 ] = copy .copy (
928
+ input_storage [i ].storage [0 ]
929
+ )
934
930
935
- # Check if inputs are missing, or if inputs were set more than once, or
936
- # if we tried to provide inputs that are supposed to be implicit.
937
- if not self .trust_input :
938
- for c in self .input_storage :
939
- if c .required and not c .provided :
940
- restore_defaults ()
931
+ # Check if inputs are missing, or if inputs were set more than once, or
932
+ # if we tried to provide inputs that are supposed to be implicit.
933
+ for arg_container in input_storage :
934
+ if arg_container .required and not arg_container .provided :
935
+ self ._restore_defaults ()
941
936
raise TypeError (
942
- f"Missing required input: { getattr (self .inv_finder [c ], 'variable' , self .inv_finder [c ])} "
937
+ f"Missing required input: { getattr (self .inv_finder [arg_container ], 'variable' , self .inv_finder [arg_container ])} "
943
938
)
944
- if c .provided > 1 :
945
- restore_defaults ()
939
+ if arg_container .provided > 1 :
940
+ self . _restore_defaults ()
946
941
raise TypeError (
947
- f"Multiple values for input: { getattr (self .inv_finder [c ], 'variable' , self .inv_finder [c ])} "
942
+ f"Multiple values for input: { getattr (self .inv_finder [arg_container ], 'variable' , self .inv_finder [arg_container ])} "
948
943
)
949
- if c .implicit and c .provided > 0 :
950
- restore_defaults ()
944
+ if arg_container .implicit and arg_container .provided > 0 :
945
+ self . _restore_defaults ()
951
946
raise TypeError (
952
- f"Tried to provide value for implicit input: { getattr (self .inv_finder [c ], 'variable' , self .inv_finder [c ])} "
947
+ f"Tried to provide value for implicit input: { getattr (self .inv_finder [arg_container ], 'variable' , self .inv_finder [arg_container ])} "
953
948
)
954
949
955
950
# Do the actual work
956
- t0_fn = time .perf_counter ()
951
+ if profile :
952
+ t0_fn = time .perf_counter ()
957
953
try :
958
954
outputs = (
959
955
self .vm ()
960
956
if output_subset is None
961
957
else self .vm (output_subset = output_subset )
962
958
)
963
959
except Exception :
964
- restore_defaults ()
960
+ self . _restore_defaults ()
965
961
if hasattr (self .vm , "position_of_error" ):
966
962
# this is a new vm-provided function or c linker
967
963
# they need this because the exception manipulation
@@ -979,26 +975,24 @@ def restore_defaults():
979
975
# old-style linkers raise their own exceptions
980
976
raise
981
977
982
- dt_fn = time .perf_counter () - t0_fn
983
- self .maker .mode .fn_time += dt_fn
984
978
if profile :
979
+ dt_fn = time .perf_counter () - t0_fn
980
+ self .maker .mode .fn_time += dt_fn
985
981
profile .vm_call_time += dt_fn
986
982
987
983
# Retrieve the values that were computed
988
984
if outputs is None :
989
985
outputs = [x .data for x in self .output_storage ]
990
- assert len (outputs ) == len (self .output_storage )
991
986
992
987
# Remove internal references to required inputs.
993
988
# These cannot be re-used anyway.
994
- for c in self . input_storage :
995
- if c .required :
996
- c .storage [0 ] = None
989
+ for arg_container in input_storage :
990
+ if arg_container .required :
991
+ arg_container .storage [0 ] = None
997
992
998
993
# if we are allowing garbage collection, remove the
999
994
# output reference from the internal storage cells
1000
995
if getattr (self .vm , "allow_gc" , False ):
1001
- assert len (self .output_storage ) == len (self .maker .fgraph .outputs )
1002
996
for o_container , o_variable in zip (
1003
997
self .output_storage , self .maker .fgraph .outputs
1004
998
):
@@ -1007,37 +1001,31 @@ def restore_defaults():
1007
1001
# WARNING: This circumvents the 'readonly' attribute in x
1008
1002
o_container .storage [0 ] = None
1009
1003
1010
- # TODO: Get rid of this and `expanded_inputs`, since all the VMs now
1011
- # perform the updates themselves
1012
1004
if getattr (self .vm , "need_update_inputs" , True ):
1013
1005
# Update the inputs that have an update function
1014
1006
for input , storage in reversed (
1015
- list (zip (self .maker .expanded_inputs , self . input_storage ))
1007
+ list (zip (self .maker .expanded_inputs , input_storage ))
1016
1008
):
1017
1009
if input .update is not None :
1018
1010
storage .data = outputs .pop ()
1019
1011
else :
1020
1012
outputs = outputs [: self .n_returned_outputs ]
1021
1013
1022
1014
# Put default values back in the storage
1023
- restore_defaults ()
1024
- #
1025
- # NOTE: This logic needs to be replicated in
1026
- # scan.
1027
- # grep for 'PROFILE_CODE'
1028
- #
1029
-
1030
- dt_call = time .perf_counter () - t0
1031
- pytensor .compile .profiling .total_fct_exec_time += dt_call
1032
- self .maker .mode .call_time += dt_call
1015
+ self ._restore_defaults ()
1016
+
1033
1017
if profile :
1018
+ dt_call = time .perf_counter () - t0
1019
+ pytensor .compile .profiling .total_fct_exec_time += dt_call
1020
+ self .maker .mode .call_time += dt_call
1034
1021
profile .fct_callcount += 1
1035
1022
profile .fct_call_time += dt_call
1036
1023
if hasattr (self .vm , "update_profile" ):
1037
1024
self .vm .update_profile (profile )
1038
1025
if profile .ignore_first_call :
1039
1026
profile .reset ()
1040
1027
profile .ignore_first_call = False
1028
+
1041
1029
if self .return_none :
1042
1030
return None
1043
1031
elif self .unpack_single and len (outputs ) == 1 and output_subset is None :
0 commit comments