
Commit 23b4fe9

Stop checking for input alias in Function.__call__
1 parent 4258475 commit 23b4fe9
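
With the alias check gone, a Function compiled with mutable/borrowed inputs no longer copies call arguments that share memory; avoiding aliasing is now the caller's job (see the test changes below). A minimal sketch of the user-facing consequence, assuming the usual pytensor.function / In API; the variables here are illustrative and not taken from the diff:

```python
# Hedged sketch: with the alias check removed, a Function compiled with
# mutable (borrowed) inputs no longer copies arguments that share memory,
# so an in-place rewrite may clobber one input while computing another.
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.compile.io import In

x = pt.vector("x")
y = pt.vector("y")
f = pytensor.function([In(x, mutable=True), In(y, mutable=True)], x + y)

v = np.ones(3, dtype=x.dtype)
f(v, v)         # aliased buffers: results may now be wrong (no defensive copy)
f(v, v.copy())  # safe: pass distinct buffers yourself
```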

File tree: 3 files changed, +23 -129 lines

Diff for: pytensor/compile/function/types.py

+12 -69

@@ -393,41 +393,6 @@ def __init__(
         assert len(self.input_storage) == len(self.maker.fgraph.inputs)
         assert len(self.output_storage) == len(self.maker.fgraph.outputs)

-        # Group indexes of inputs that are potentially aliased to each other
-        # Note: Historically, we only worried about aliasing inputs if they belonged to the same type,
-        # even though there could be two distinct types that use the same kinds of underlying objects.
-        potential_aliased_input_groups = []
-        for inp in maker.inputs:
-            # If the input is a shared variable, the memory region is under PyTensor control
-            # and can't be aliased.
-            if not (
-                isinstance(inp, In)
-                and inp.borrow
-                and not inp.shared
-                and hasattr(inp.variable.type, "may_share_memory")
-            ):
-                continue
-
-            for group in potential_aliased_input_groups:
-                # If one is super of the other, that means one could be replaced by the other
-                if any(
-                    inp.variable.type.is_super(other_inp.variable.type)
-                    or other_inp.variable.type.is_super(inp.variable.type)
-                    for other_inp in group
-                ):
-                    group.append(inp)
-                    break
-            else:  # no break
-                # Input makes a new group
-                potential_aliased_input_groups.append([inp])
-
-        # Potential aliased inputs are those that belong to the same group
-        self._potential_aliased_input_groups: tuple[tuple[int, ...], ...] = tuple(
-            tuple(maker.inputs.index(inp) for inp in group)
-            for group in potential_aliased_input_groups
-            if len(group) > 1
-        )
-
         # We will be popping stuff off this `containers` object. It is a copy.
         containers = list(self.input_storage)
         finder = {}

@@ -844,11 +809,18 @@ def __call__(self, *args, **kwargs):
         if self.output_keys is not None:
             output_subset = [self.output_keys.index(key) for key in output_subset]

-        # Reinitialize each container's 'provided' counter
         if self.trust_input:
+            # Set positional arguments
             for arg_container, arg in zip(input_storage, args, strict=False):
                 arg_container.storage[0] = arg
+
+            # Set keyword arguments
+            if kwargs:  # for speed, skip the items for empty kwargs
+                for k, arg in kwargs.items():
+                    self[k] = arg
+
         else:
+            # Reinitialize each container's 'provided' counter
             for arg_container in input_storage:
                 arg_container.provided = 0
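
The hunk above moves keyword-argument handling into the trust_input fast path, so both positional and keyword arguments are now written straight into storage without validation. A rough usage sketch, assuming the public trust_input attribute; dtypes are matched by hand because this path skips conversion:

```python
# Hedged sketch of the trust_input fast path: arguments are written straight
# into the input containers, positionals via zip() and keywords via self[k].
import numpy as np
import pytensor
import pytensor.tensor as pt

a = pt.vector("a")
b = pt.vector("b")
f = pytensor.function([a, b], a * b)

f.trust_input = True  # caller promises correctly-typed, non-aliased inputs
out = f(np.ones(3, dtype=a.dtype), b=np.full(3, 2.0, dtype=b.dtype))
```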

@@ -899,39 +871,10 @@ def __call__(self, *args, **kwargs):
                     raise
                 arg_container.provided += 1

-        # Set keyword arguments
-        if kwargs:  # for speed, skip the items for empty kwargs
-            for k, arg in kwargs.items():
-                self[k] = arg
-
-        if not self.trust_input:
-            # Collect aliased inputs among the storage space
-            for potential_group in self._potential_aliased_input_groups:
-                args_share_memory: list[list[int]] = []
-                for i in potential_group:
-                    i_type = self.maker.inputs[i].variable.type
-                    i_val = input_storage[i].storage[0]
-
-                    # Check if value is aliased with any of the values in one of the groups
-                    for j_group in args_share_memory:
-                        if any(
-                            i_type.may_share_memory(input_storage[j].storage[0], i_val)
-                            for j in j_group
-                        ):
-                            j_group.append(i)
-                            break
-                    else:  # no break
-                        # Create a new group
-                        args_share_memory.append([i])
-
-                # Check for groups of more than one argument that share memory
-                for group in args_share_memory:
-                    if len(group) > 1:
-                        # copy all but the first
-                        for i in group[1:]:
-                            input_storage[i].storage[0] = copy.copy(
-                                input_storage[i].storage[0]
-                            )
+            # Set keyword arguments
+            if kwargs:  # for speed, skip the items for empty kwargs
+                for k, arg in kwargs.items():
+                    self[k] = arg

         # Check if inputs are missing, or if inputs were set more than once, or
         # if we tried to provide inputs that are supposed to be implicit.
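
For reference, the block deleted here grouped call arguments whose storage might overlap and copied all but the first member of each group. A standalone sketch of that grouping logic, using np.may_share_memory as a stand-in for the per-Type may_share_memory check:

```python
# Minimal sketch of the removed defensive-copy logic (assumption: the real
# check dispatched through each input's Type.may_share_memory).
import copy

import numpy as np


def copy_aliased(values):
    """Group values that share memory and copy all but the first of each group."""
    groups: list[list[int]] = []
    for i, val in enumerate(values):
        for group in groups:
            if any(np.may_share_memory(values[j], val) for j in group):
                group.append(i)
                break
        else:  # no break: value starts a new group
            groups.append([i])
    for group in groups:
        for i in group[1:]:  # keep the first, copy the rest
            values[i] = copy.copy(values[i])
    return values


v = np.ones(4)
a, b = copy_aliased([v[:3], v[1:]])  # overlapping views of the same buffer
assert not np.may_share_memory(a, b)
```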

Diff for: tests/compile/function/test_pfunc.py

+11 -8

@@ -732,16 +732,13 @@ class TestAliasingRules:
     # 2. shared variables are allocated in this memory space, as are the
     # temporaries used for Function evaluation.
     #
-    # 3. Physically, this managed memory space may be spread across the host,
-    # on a GPU device(s), or even on a remote machine.
-    #
-    # 4. PyTensor assumes that shared variables are never aliased to one another,
+    # 3. PyTensor assumes that shared variables are never aliased to one another,
     # and tries to make it impossible to accidentally alias them.
     #
-    # 5. PyTensor's managed data is constant while PyTensor Functions are not running
+    # 4. PyTensor's managed data is constant while PyTensor Functions are not running
     # and pytensor library code is not running.
     #
-    # 6. The default behaviour of Function is to return user-space values for
+    # 5. The default behaviour of Function is to return user-space values for
     # outputs, but this can be overridden (borrow=True) for better performance,
     # in which case the returned value may be aliased to managed memory, and
     # potentially invalidated by the next PyTensor Function call or call to pytensor

@@ -810,6 +807,9 @@ def test_sparse_input_aliasing_affecting_inplace_operations(self):
         assert np.allclose(vals.todense(), bogus_vals.todense())

     def test_input_aliasing_affecting_inplace_operations(self):
+        # Note: The input aliasing check was disabled, so this test now just confirms that wrong values
+        # will be obtained if the inputs are aliased.
+
         # Note: to trigger this bug with pytensor rev 4586:2bc6fc7f218b,
         # you need to make in inputs mutable (so that inplace
         # operations are used) and to break the elemwise composition

@@ -860,9 +860,12 @@ def test_input_aliasing_affecting_inplace_operations(self):
         v_copy = v.copy()
         vals = f(v, v_copy, m, m_copy)

-        assert np.allclose(vals, bogus_vals)
+        assert not np.allclose(vals, bogus_vals)

     def test_partial_input_aliasing_affecting_inplace_operations(self):
+        # Note: The input aliasing check was disabled, so this test now just confirms that wrong values
+        # will be obtained if the inputs are aliased.
+
         # Note: to trigger this bug with pytensor rev 4586:2bc6fc7f218b,
         # you need to make in inputs mutable ( so that inplace
         # operations are used) and to break the elemwise composition

@@ -906,7 +909,7 @@ def test_partial_input_aliasing_affecting_inplace_operations(self):
         v_copy2 = v.copy()
         vals = f(v[:2], v_copy1[1:3], v_copy2[2:4], m, m_copy1, m_copy2)

-        assert np.allclose(vals, bogus_vals)
+        assert not np.allclose(vals, bogus_vals)

     def test_potential_output_aliasing_induced_by_updates(self):
         A = self.shared(np.zeros((2, 2)))
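
These tests now only document that aliased inputs can produce wrong values. Callers who relied on the old runtime copy can add their own guard; the helper below is hypothetical (not part of PyTensor) and uses NumPy's exact overlap test:

```python
# Hedged sketch of a caller-side guard that roughly replaces the removed check.
# np.shares_memory does an exact overlap test; np.may_share_memory is the cheaper,
# bounds-only variant that the removed Type.may_share_memory check resembles.
import numpy as np


def assert_no_aliasing(*arrays):
    for i, a in enumerate(arrays):
        for b in arrays[i + 1:]:
            if np.shares_memory(a, b):
                raise ValueError("mutable inputs alias each other; pass copies instead")


v = np.arange(4.0)
assert_no_aliasing(v[:2], v[2:])    # disjoint views: fine
# assert_no_aliasing(v[:3], v[1:])  # overlapping views: would raise
```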

Diff for: tests/compile/function/test_types.py

-52

@@ -752,52 +752,6 @@ def test_default_values(self):
         except TypeError:
             assert funct(first=1) == x

-    def test_check_for_aliased_inputs(self):
-        b = np.random.random((5, 4))
-        s1 = shared(b)
-        s2 = shared(b)
-        x1 = vector()
-        x2 = vector(shape=(3,))
-        x3 = vector(shape=(1,))
-
-        # Assert cases we should not check for aliased inputs
-        for d in [
-            dict(outputs=[s1 + 1]),
-            dict(outputs=[s1 + 1, s2 + 3]),
-            dict(outputs=[s1 + 1], updates=[(s2, s2 + 3)]),
-            dict(inputs=[x1], outputs=[x1 + 1], updates=[(s2, s2 + 3)]),
-            dict(
-                inputs=[In(x1, mutable=True)], outputs=[x1 + 1], updates=[(s2, s2 + 3)]
-            ),
-            dict(
-                inputs=[In(x2, mutable=True), In(x3, mutable=True)],
-                outputs=[x2 + 2, x3 + 3],
-            ),
-        ]:
-            if "inputs" not in d:
-                d["inputs"] = []
-            f = function(**d)
-            assert not f._potential_aliased_input_groups, d
-
-        # Assert cases we should check for aliased inputs
-        for d in [
-            dict(
-                inputs=[In(x1, mutable=True), In(x2, mutable=True)],
-                outputs=[x1 + 1, x2 + 2],
-                updates=[(s2, s2 + 3)],
-            ),
-            dict(
-                inputs=[In(x1, mutable=True), In(x3, mutable=True)],
-                outputs=[x1 + 1, x3 + 3],
-                updates=[(s2, s2 + 3)],
-            ),
-        ]:
-            if "inputs" not in d:
-                d["inputs"] = []
-            f = function(**d)
-
-            assert f._potential_aliased_input_groups, d
-
     def test_output_dictionary(self):
         # Tests that function works when outputs is a dictionary

@@ -939,12 +893,6 @@ def test_deepcopy(self):
         assert x not in g.container
         assert x not in g.value
         assert len(f.defaults) == len(g.defaults)
-        # Shared variable is the first input
-        assert (
-            f._potential_aliased_input_groups
-            == g._potential_aliased_input_groups
-            == ((1, 2),)
-        )
         assert f.name == g.name
         assert f.maker.fgraph.name == g.maker.fgraph.name
         # print(f"{f.defaults = }")
