Skip to content

Commit e00abf3

Browse files
committed
Inverse need not be a symbolic input in PermuteRowElements
1 parent 3cdcfde commit e00abf3

File tree

3 files changed

+22
-28
lines changed

3 files changed

+22
-28
lines changed

pytensor/tensor/basic.py

+16-21
Original file line numberDiff line numberDiff line change
@@ -3481,20 +3481,18 @@ class PermuteRowElements(Op):
34813481
permutation instead.
34823482
"""
34833483

3484-
__props__ = ()
3484+
__props__ = ("inverse",)
3485+
3486+
def __init__(self, inverse: bool):
3487+
super().__init__()
3488+
self.inverse = inverse
34853489

3486-
def make_node(self, x, y, inverse):
3490+
def make_node(self, x, y):
34873491
x = as_tensor_variable(x)
34883492
y = as_tensor_variable(y)
3489-
if inverse: # as_tensor_variable does not accept booleans
3490-
inverse = as_tensor_variable(1)
3491-
else:
3492-
inverse = as_tensor_variable(0)
34933493

34943494
# y should contain integers
34953495
assert y.type.dtype in integer_dtypes
3496-
# Inverse should be an integer scalar
3497-
assert inverse.type.ndim == 0 and inverse.type.dtype in integer_dtypes
34983496

34993497
# Match shapes of x and y
35003498
x_dim = x.type.ndim
@@ -3511,7 +3509,7 @@ def make_node(self, x, y, inverse):
35113509
]
35123510
out_type = tensor(dtype=x.type.dtype, shape=out_shape)
35133511

3514-
inputlist = [x, y, inverse]
3512+
inputlist = [x, y]
35153513
outputlist = [out_type]
35163514
return Apply(self, inputlist, outputlist)
35173515

@@ -3564,7 +3562,7 @@ def _rec_perform(self, node, x, y, inverse, out, curdim):
35643562
raise ValueError(f"Dimension mismatch: {xs0}, {ys0}")
35653563

35663564
def perform(self, node, inp, out):
3567-
x, y, inverse = inp
3565+
x, y = inp
35683566
(outs,) = out
35693567
x_s = x.shape
35703568
y_s = y.shape
@@ -3587,7 +3585,7 @@ def perform(self, node, inp, out):
35873585
if outs[0] is None or outs[0].shape != out_s:
35883586
outs[0] = np.empty(out_s, dtype=x.dtype)
35893587

3590-
self._rec_perform(node, x, y, inverse, outs[0], curdim=0)
3588+
self._rec_perform(node, x, y, self.inverse, outs[0], curdim=0)
35913589

35923590
def infer_shape(self, fgraph, node, in_shapes):
35933591
from pytensor.tensor.math import maximum
@@ -3599,14 +3597,14 @@ def infer_shape(self, fgraph, node, in_shapes):
35993597
return [out_shape]
36003598

36013599
def grad(self, inp, grads):
3602-
from pytensor.tensor.math import Sum, eq
3600+
from pytensor.tensor.math import Sum
36033601

3604-
x, y, inverse = inp
3602+
x, y = inp
36053603
(gz,) = grads
36063604
# First, compute the gradient wrt the broadcasted x.
36073605
# If 'inverse' is False (0), apply the inverse of y on gz.
36083606
# Else, apply y on gz.
3609-
gx = permute_row_elements(gz, y, eq(inverse, 0))
3607+
gx = permute_row_elements(gz, y, not self.inverse)
36103608

36113609
# If x has been broadcasted along some axes, we need to sum
36123610
# the gradient over these axes, but keep the dimension (as
@@ -3643,20 +3641,17 @@ def grad(self, inp, grads):
36433641
if x.type.dtype in discrete_dtypes:
36443642
gx = x.zeros_like()
36453643

3646-
# The elements of y and of inverse both affect the output,
3644+
# The elements of y affect the output,
36473645
# so they are connected to the output,
36483646
# and the transformation isn't defined if their values
36493647
# are non-integer, so the gradient with respect to them is
36503648
# undefined
36513649

3652-
return [gx, grad_undefined(self, 1, y), grad_undefined(self, 1, inverse)]
3653-
3654-
3655-
_permute_row_elements = PermuteRowElements()
3650+
return [gx, grad_undefined(self, 1, y)]
36563651

36573652

3658-
def permute_row_elements(x, y, inverse=0):
3659-
return _permute_row_elements(x, y, inverse)
3653+
def permute_row_elements(x, y, inverse=False):
3654+
return PermuteRowElements(inverse=inverse)(x, y)
36603655

36613656

36623657
def inverse_permutation(perm):

pytensor/tensor/rewriting/subtensor.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1147,7 +1147,7 @@ def merge_two_slices(fgraph, slice1, len1, slice2, len2):
11471147
val = switch(le(len2, 0), len1 + 1, val)
11481148
val = switch(ge(sl2, len2), len1 + 1, val)
11491149
val = switch(lt(sl2, 0), -len1 - 1, val)
1150-
if sl1.step:
1150+
if sl1.step is not None:
11511151
val = switch(eq(sl1.step, 0), len1 + 1, val)
11521152
return val
11531153
else:

tests/tensor/test_basic.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -3972,21 +3972,20 @@ def test_PermuteRowElements(self):
39723972
advec = dvector()
39733973
aivec = ivector()
39743974

3975-
abool = True
39763975
rng = np.random.default_rng(utt.fetch_seed())
39773976
advec_val = random(5)
39783977
aivec_val = rng.permutation(5).astype("int32")
39793978
self._compile_and_check(
39803979
[advec, aivec],
3981-
[PermuteRowElements()(advec, aivec, abool)],
3980+
[PermuteRowElements(inverse=True)(advec, aivec)],
39823981
[advec_val, aivec_val],
39833982
PermuteRowElements,
39843983
)
39853984

39863985
admat_val = random(3, 5)
39873986
self._compile_and_check(
39883987
[admat, aivec],
3989-
[PermuteRowElements()(admat, aivec, abool)],
3988+
[PermuteRowElements(inverse=False)(admat, aivec)],
39903989
[admat_val, aivec_val],
39913990
PermuteRowElements,
39923991
)
@@ -3995,7 +3994,7 @@ def test_PermuteRowElements(self):
39953994
adtens3_val = random(3, 2, 5)
39963995
self._compile_and_check(
39973996
[adtens3, aivec],
3998-
[PermuteRowElements()(adtens3, aivec, abool)],
3997+
[PermuteRowElements(inverse=True)(adtens3, aivec)],
39993998
[adtens3_val, aivec_val],
40003999
PermuteRowElements,
40014000
)
@@ -4008,7 +4007,7 @@ def test_PermuteRowElements(self):
40084007
admat_val = random(3, 5)
40094008
self._compile_and_check(
40104009
[admat, aimat],
4011-
[PermuteRowElements()(admat, aimat, abool)],
4010+
[PermuteRowElements(inverse=False)(admat, aimat)],
40124011
[admat_val, aimat_val],
40134012
PermuteRowElements,
40144013
)
@@ -4023,7 +4022,7 @@ def test_PermuteRowElements(self):
40234022
aitens3_val[1, ::, ::] = bimat_val
40244023
self._compile_and_check(
40254024
[admat, aitens3],
4026-
[PermuteRowElements()(admat, aitens3, abool)],
4025+
[PermuteRowElements(inverse=True)(admat, aitens3)],
40274026
[admat_val, aitens3_val],
40284027
PermuteRowElements,
40294028
)

0 commit comments

Comments
 (0)