@@ -3481,20 +3481,18 @@ class PermuteRowElements(Op):
3481
3481
permutation instead.
3482
3482
"""
3483
3483
3484
- __props__ = ()
3484
+ __props__ = ("inverse" ,)
3485
+
3486
+ def __init__ (self , inverse : bool ):
3487
+ super ().__init__ ()
3488
+ self .inverse = inverse
3485
3489
3486
- def make_node (self , x , y , inverse ):
3490
+ def make_node (self , x , y ):
3487
3491
x = as_tensor_variable (x )
3488
3492
y = as_tensor_variable (y )
3489
- if inverse : # as_tensor_variable does not accept booleans
3490
- inverse = as_tensor_variable (1 )
3491
- else :
3492
- inverse = as_tensor_variable (0 )
3493
3493
3494
3494
# y should contain integers
3495
3495
assert y .type .dtype in integer_dtypes
3496
- # Inverse should be an integer scalar
3497
- assert inverse .type .ndim == 0 and inverse .type .dtype in integer_dtypes
3498
3496
3499
3497
# Match shapes of x and y
3500
3498
x_dim = x .type .ndim
@@ -3511,7 +3509,7 @@ def make_node(self, x, y, inverse):
3511
3509
]
3512
3510
out_type = tensor (dtype = x .type .dtype , shape = out_shape )
3513
3511
3514
- inputlist = [x , y , inverse ]
3512
+ inputlist = [x , y ]
3515
3513
outputlist = [out_type ]
3516
3514
return Apply (self , inputlist , outputlist )
3517
3515
@@ -3564,7 +3562,7 @@ def _rec_perform(self, node, x, y, inverse, out, curdim):
3564
3562
raise ValueError (f"Dimension mismatch: { xs0 } , { ys0 } " )
3565
3563
3566
3564
def perform (self , node , inp , out ):
3567
- x , y , inverse = inp
3565
+ x , y = inp
3568
3566
(outs ,) = out
3569
3567
x_s = x .shape
3570
3568
y_s = y .shape
@@ -3587,7 +3585,7 @@ def perform(self, node, inp, out):
3587
3585
if outs [0 ] is None or outs [0 ].shape != out_s :
3588
3586
outs [0 ] = np .empty (out_s , dtype = x .dtype )
3589
3587
3590
- self ._rec_perform (node , x , y , inverse , outs [0 ], curdim = 0 )
3588
+ self ._rec_perform (node , x , y , self . inverse , outs [0 ], curdim = 0 )
3591
3589
3592
3590
def infer_shape (self , fgraph , node , in_shapes ):
3593
3591
from pytensor .tensor .math import maximum
@@ -3599,14 +3597,14 @@ def infer_shape(self, fgraph, node, in_shapes):
3599
3597
return [out_shape ]
3600
3598
3601
3599
def grad (self , inp , grads ):
3602
- from pytensor .tensor .math import Sum , eq
3600
+ from pytensor .tensor .math import Sum
3603
3601
3604
- x , y , inverse = inp
3602
+ x , y = inp
3605
3603
(gz ,) = grads
3606
3604
# First, compute the gradient wrt the broadcasted x.
3607
3605
# If 'inverse' is False, apply the inverse of y on gz.
3608
3606
# Else, apply y on gz.
3609
- gx = permute_row_elements (gz , y , eq ( inverse , 0 ) )
3607
+ gx = permute_row_elements (gz , y , not self . inverse )
3610
3608
3611
3609
# If x has been broadcasted along some axes, we need to sum
3612
3610
# the gradient over these axes, but keep the dimension (as
@@ -3643,20 +3641,17 @@ def grad(self, inp, grads):
3643
3641
if x .type .dtype in discrete_dtypes :
3644
3642
gx = x .zeros_like ()
3645
3643
3646
- # The elements of y and of inverse both affect the output,
3644
+ # The elements of y affect the output,
3647
3645
# so they are connected to the output,
3648
3646
# and the transformation isn't defined if their values
3649
3647
# are non-integer, so the gradient with respect to them is
3650
3648
# undefined
3651
3649
3652
- return [gx , grad_undefined (self , 1 , y ), grad_undefined (self , 1 , inverse )]
3653
-
3654
-
3655
- _permute_row_elements = PermuteRowElements ()
3650
+ return [gx , grad_undefined (self , 1 , y )]
3656
3651
3657
3652
3658
- def permute_row_elements (x , y , inverse = 0 ):
3659
- return _permute_row_elements ( x , y , inverse )
3653
+ def permute_row_elements (x , y , inverse = False ):
3654
+ return PermuteRowElements ( inverse = inverse )( x , y )
3660
3655
3661
3656
3662
3657
def inverse_permutation (perm ):
0 commit comments