Skip to content

Commit d77f26c

Browse files
committed
Cleanup Function.__call__
1 parent 147c892 commit d77f26c

File tree

7 files changed

+136
-161
lines changed

7 files changed

+136
-161
lines changed

Diff for: pytensor/compile/function/types.py

+113-125
Original file line numberDiff line numberDiff line change
@@ -326,8 +326,8 @@ class Function:
326326
def __init__(
327327
self,
328328
vm: "VM",
329-
input_storage,
330-
output_storage,
329+
input_storage: list[Container],
330+
output_storage: list[Container],
331331
indices,
332332
outputs,
333333
defaults,
@@ -372,7 +372,6 @@ def __init__(
372372
name
373373
A string name.
374374
"""
375-
# TODO: Rename to `vm`
376375
self.vm = vm
377376
self.input_storage = input_storage
378377
self.output_storage = output_storage
@@ -388,31 +387,49 @@ def __init__(
388387
self.nodes_with_inner_function = []
389388
self.output_keys = output_keys
390389

391-
# See if we have any mutable / borrow inputs
392-
# TODO: this only need to be set if there is more than one input
393-
self._check_for_aliased_inputs = False
394-
for i in maker.inputs:
395-
# If the input is a shared variable, the memory region is
396-
# under PyTensor control and so we don't need to check if it
397-
# is aliased as we never do that.
398-
if (
399-
isinstance(i, In)
400-
and not i.shared
401-
and (getattr(i, "borrow", False) or getattr(i, "mutable", False))
390+
assert len(self.input_storage) == len(self.maker.fgraph.inputs)
391+
assert len(self.output_storage) == len(self.maker.fgraph.outputs)
392+
393+
# Group indexes of inputs that are potentially aliased to each other
394+
# Note: Historically, we only worried about aliasing inputs if they belonged to the same type,
395+
# even though there could be two distinct types that use the same kinds of underlying objects.
396+
potential_aliased_input_groups = []
397+
for inp in maker.inputs:
398+
# If the input is a shared variable, the memory region is under PyTensor control
399+
# and can't be aliased.
400+
if not (
401+
isinstance(inp, In)
402+
and inp.borrow
403+
and not inp.shared
404+
and hasattr(inp.variable.type, "may_share_memory")
402405
):
403-
self._check_for_aliased_inputs = True
404-
break
406+
continue
407+
408+
for group in potential_aliased_input_groups:
409+
# If one is super of the other, that means one could be replaced by the other
410+
if any(
411+
inp.variable.type.is_super(other_inp.variable.type)
412+
or other_inp.variable.type.is_super(inp.variable.type)
413+
for other_inp in group
414+
):
415+
group.append(inp)
416+
break
417+
else: # no break
418+
# Input makes a new group
419+
potential_aliased_input_groups.append([inp])
420+
421+
# Potential aliased inputs are those that belong to the same group
422+
self._potential_aliased_input_groups: tuple[tuple[int, ...], ...] = tuple(
423+
tuple(maker.inputs.index(inp) for inp in group)
424+
for group in potential_aliased_input_groups
425+
if len(group) > 1
426+
)
405427

406428
# We will be popping stuff off this `containers` object. It is a copy.
407429
containers = list(self.input_storage)
408430
finder = {}
409431
inv_finder = {}
410432

411-
def distribute(indices, cs, value):
412-
input.distribute(value, indices, cs)
413-
for c in cs:
414-
c.provided += 1
415-
416433
# Store the list of names of named inputs.
417434
named_inputs = []
418435
# Count the number of un-named inputs.
@@ -777,6 +794,13 @@ def checkSV(sv_ori, sv_rpl):
777794
f_cpy.maker.fgraph.name = name
778795
return f_cpy
779796

797+
def _restore_defaults(self):
798+
for i, (required, refeed, value) in enumerate(self.defaults):
799+
if refeed:
800+
if isinstance(value, Container):
801+
value = value.storage[0]
802+
self[i] = value
803+
780804
def __call__(self, *args, **kwargs):
781805
"""
782806
Evaluates value of a function on given arguments.
@@ -805,52 +829,43 @@ def __call__(self, *args, **kwargs):
805829
List of outputs on indices/keys from ``output_subset`` or all of them,
806830
if ``output_subset`` is not passed.
807831
"""
808-
809-
def restore_defaults():
810-
for i, (required, refeed, value) in enumerate(self.defaults):
811-
if refeed:
812-
if isinstance(value, Container):
813-
value = value.storage[0]
814-
self[i] = value
815-
832+
input_storage = self.input_storage
816833
profile = self.profile
817-
t0 = time.perf_counter()
834+
835+
if profile:
836+
t0 = time.perf_counter()
818837

819838
output_subset = kwargs.pop("output_subset", None)
820839
if output_subset is not None and self.output_keys is not None:
821840
output_subset = [self.output_keys.index(key) for key in output_subset]
822841

823842
# Reinitialize each container's 'provided' counter
824843
if self.trust_input:
825-
i = 0
826-
for arg in args:
827-
s = self.input_storage[i]
828-
s.storage[0] = arg
829-
i += 1
844+
for arg_container, arg in zip(input_storage, args, strict=False):
845+
arg_container.storage[0] = arg
830846
else:
831-
for c in self.input_storage:
832-
c.provided = 0
847+
for arg_container in input_storage:
848+
arg_container.provided = 0
833849

834-
if len(args) + len(kwargs) > len(self.input_storage):
850+
if len(args) + len(kwargs) > len(input_storage):
835851
raise TypeError("Too many parameter passed to pytensor function")
836852

837853
# Set positional arguments
838-
i = 0
839-
for arg in args:
840-
# TODO: provide a option for skipping the filter if we really
841-
# want speed.
842-
s = self.input_storage[i]
843-
# see this emails for a discuation about None as input
854+
for arg_container, arg in zip(input_storage, args, strict=False):
855+
# See discussion about None as input
844856
# https://groups.google.com/group/theano-dev/browse_thread/thread/920a5e904e8a8525/4f1b311a28fc27e5
845857
if arg is None:
846-
s.storage[0] = arg
858+
arg_container.storage[0] = arg
847859
else:
848860
try:
849-
s.storage[0] = s.type.filter(
850-
arg, strict=s.strict, allow_downcast=s.allow_downcast
861+
arg_container.storage[0] = arg_container.type.filter(
862+
arg,
863+
strict=arg_container.strict,
864+
allow_downcast=arg_container.allow_downcast,
851865
)
852866

853867
except Exception as e:
868+
i = input_storage.index(arg_container)
854869
function_name = "pytensor function"
855870
argument_name = "argument"
856871
if self.name:
@@ -875,93 +890,74 @@ def restore_defaults():
875890
+ function_name
876891
+ f" at index {int(i)} (0-based). {where}"
877892
) + e.args
878-
restore_defaults()
893+
self._restore_defaults()
879894
raise
880-
s.provided += 1
881-
i += 1
895+
arg_container.provided += 1
882896

883897
# Set keyword arguments
884898
if kwargs: # for speed, skip the items for empty kwargs
885899
for k, arg in kwargs.items():
886900
self[k] = arg
887901

888-
if (
889-
not self.trust_input
890-
and
891-
# The getattr is only needed for old pickle
892-
getattr(self, "_check_for_aliased_inputs", True)
893-
):
902+
if not self.trust_input:
894903
# Collect aliased inputs among the storage space
895-
args_share_memory = []
896-
for i in range(len(self.input_storage)):
897-
i_var = self.maker.inputs[i].variable
898-
i_val = self.input_storage[i].storage[0]
899-
if hasattr(i_var.type, "may_share_memory"):
900-
is_aliased = False
901-
for j in range(len(args_share_memory)):
902-
group_j = zip(
903-
[
904-
self.maker.inputs[k].variable
905-
for k in args_share_memory[j]
906-
],
907-
[
908-
self.input_storage[k].storage[0]
909-
for k in args_share_memory[j]
910-
],
911-
)
904+
for potential_group in self._potential_aliased_input_groups:
905+
args_share_memory: list[list[int]] = []
906+
for i in potential_group:
907+
i_type = self.maker.inputs[i].variable.type
908+
i_val = input_storage[i].storage[0]
909+
910+
# Check if value is aliased with any of the values in one of the groups
911+
for j_group in args_share_memory:
912912
if any(
913-
(
914-
var.type is i_var.type
915-
and var.type.may_share_memory(val, i_val)
916-
)
917-
for (var, val) in group_j
913+
i_type.may_share_memory(input_storage[j].storage[0], i_val)
914+
for j in j_group
918915
):
919-
is_aliased = True
920-
args_share_memory[j].append(i)
916+
j_group.append(i)
921917
break
922-
923-
if not is_aliased:
918+
else: # no break
919+
# Create a new group
924920
args_share_memory.append([i])
925921

926-
# Check for groups of more than one argument that share memory
927-
for group in args_share_memory:
928-
if len(group) > 1:
929-
# copy all but the first
930-
for j in group[1:]:
931-
self.input_storage[j].storage[0] = copy.copy(
932-
self.input_storage[j].storage[0]
933-
)
922+
# Check for groups of more than one argument that share memory
923+
for group in args_share_memory:
924+
if len(group) > 1:
925+
# copy all but the first
926+
for i in group[1:]:
927+
input_storage[i].storage[0] = copy.copy(
928+
input_storage[i].storage[0]
929+
)
934930

935-
# Check if inputs are missing, or if inputs were set more than once, or
936-
# if we tried to provide inputs that are supposed to be implicit.
937-
if not self.trust_input:
938-
for c in self.input_storage:
939-
if c.required and not c.provided:
940-
restore_defaults()
931+
# Check if inputs are missing, or if inputs were set more than once, or
932+
# if we tried to provide inputs that are supposed to be implicit.
933+
for arg_container in input_storage:
934+
if arg_container.required and not arg_container.provided:
935+
self._restore_defaults()
941936
raise TypeError(
942-
f"Missing required input: {getattr(self.inv_finder[c], 'variable', self.inv_finder[c])}"
937+
f"Missing required input: {getattr(self.inv_finder[arg_container], 'variable', self.inv_finder[arg_container])}"
943938
)
944-
if c.provided > 1:
945-
restore_defaults()
939+
if arg_container.provided > 1:
940+
self._restore_defaults()
946941
raise TypeError(
947-
f"Multiple values for input: {getattr(self.inv_finder[c], 'variable', self.inv_finder[c])}"
942+
f"Multiple values for input: {getattr(self.inv_finder[arg_container], 'variable', self.inv_finder[arg_container])}"
948943
)
949-
if c.implicit and c.provided > 0:
950-
restore_defaults()
944+
if arg_container.implicit and arg_container.provided > 0:
945+
self._restore_defaults()
951946
raise TypeError(
952-
f"Tried to provide value for implicit input: {getattr(self.inv_finder[c], 'variable', self.inv_finder[c])}"
947+
f"Tried to provide value for implicit input: {getattr(self.inv_finder[arg_container], 'variable', self.inv_finder[arg_container])}"
953948
)
954949

955950
# Do the actual work
956-
t0_fn = time.perf_counter()
951+
if profile:
952+
t0_fn = time.perf_counter()
957953
try:
958954
outputs = (
959955
self.vm()
960956
if output_subset is None
961957
else self.vm(output_subset=output_subset)
962958
)
963959
except Exception:
964-
restore_defaults()
960+
self._restore_defaults()
965961
if hasattr(self.vm, "position_of_error"):
966962
# this is a new vm-provided function or c linker
967963
# they need this because the exception manipulation
@@ -979,26 +975,24 @@ def restore_defaults():
979975
# old-style linkers raise their own exceptions
980976
raise
981977

982-
dt_fn = time.perf_counter() - t0_fn
983-
self.maker.mode.fn_time += dt_fn
984978
if profile:
979+
dt_fn = time.perf_counter() - t0_fn
980+
self.maker.mode.fn_time += dt_fn
985981
profile.vm_call_time += dt_fn
986982

987983
# Retrieve the values that were computed
988984
if outputs is None:
989985
outputs = [x.data for x in self.output_storage]
990-
assert len(outputs) == len(self.output_storage)
991986

992987
# Remove internal references to required inputs.
993988
# These cannot be re-used anyway.
994-
for c in self.input_storage:
995-
if c.required:
996-
c.storage[0] = None
989+
for arg_container in input_storage:
990+
if arg_container.required:
991+
arg_container.storage[0] = None
997992

998993
# if we are allowing garbage collection, remove the
999994
# output reference from the internal storage cells
1000995
if getattr(self.vm, "allow_gc", False):
1001-
assert len(self.output_storage) == len(self.maker.fgraph.outputs)
1002996
for o_container, o_variable in zip(
1003997
self.output_storage, self.maker.fgraph.outputs
1004998
):
@@ -1007,37 +1001,31 @@ def restore_defaults():
10071001
# WARNING: This circumvents the 'readonly' attribute in x
10081002
o_container.storage[0] = None
10091003

1010-
# TODO: Get rid of this and `expanded_inputs`, since all the VMs now
1011-
# perform the updates themselves
10121004
if getattr(self.vm, "need_update_inputs", True):
10131005
# Update the inputs that have an update function
10141006
for input, storage in reversed(
1015-
list(zip(self.maker.expanded_inputs, self.input_storage))
1007+
list(zip(self.maker.expanded_inputs, input_storage))
10161008
):
10171009
if input.update is not None:
10181010
storage.data = outputs.pop()
10191011
else:
10201012
outputs = outputs[: self.n_returned_outputs]
10211013

10221014
# Put default values back in the storage
1023-
restore_defaults()
1024-
#
1025-
# NOTE: This logic needs to be replicated in
1026-
# scan.
1027-
# grep for 'PROFILE_CODE'
1028-
#
1029-
1030-
dt_call = time.perf_counter() - t0
1031-
pytensor.compile.profiling.total_fct_exec_time += dt_call
1032-
self.maker.mode.call_time += dt_call
1015+
self._restore_defaults()
1016+
10331017
if profile:
1018+
dt_call = time.perf_counter() - t0
1019+
pytensor.compile.profiling.total_fct_exec_time += dt_call
1020+
self.maker.mode.call_time += dt_call
10341021
profile.fct_callcount += 1
10351022
profile.fct_call_time += dt_call
10361023
if hasattr(self.vm, "update_profile"):
10371024
self.vm.update_profile(profile)
10381025
if profile.ignore_first_call:
10391026
profile.reset()
10401027
profile.ignore_first_call = False
1028+
10411029
if self.return_none:
10421030
return None
10431031
elif self.unpack_single and len(outputs) == 1 and output_subset is None:

Diff for: pytensor/gradient.py

-3
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,6 @@ def fiter_variable(self, other):
128128
" a symbolic placeholder."
129129
)
130130

131-
def may_share_memory(a, b):
132-
return False
133-
134131
def value_eq(a, b, force_same_dtype=True):
135132
raise AssertionError(
136133
"If you're assigning to a DisconnectedType you're"

0 commit comments

Comments
 (0)