Skip to content

Commit 1acacb8

Browse files
committed
Remove Join view flag
Do not normalize constant axis in make_node and fix rewrite that assumed this would always be positive
1 parent 4e59f21 commit 1acacb8

File tree

8 files changed

+159
-318
lines changed

8 files changed

+159
-318
lines changed

pytensor/link/jax/dispatch/tensor_basic.py

+1-8
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,7 @@ def jax_funcify_Join(op, **kwargs):
8787
def join(axis, *tensors):
8888
# tensors could also be tuples, and in this case they don't have a ndim
8989
tensors = [jnp.asarray(tensor) for tensor in tensors]
90-
view = op.view
91-
if (view != -1) and all(
92-
tensor.shape[axis] == 0 for tensor in tensors[0:view] + tensors[view + 1 :]
93-
):
94-
return tensors[view]
95-
96-
else:
97-
return jnp.concatenate(tensors, axis=axis)
90+
return jnp.concatenate(tensors, axis=axis)
9891

9992
return join
10093

pytensor/link/numba/dispatch/tensor_basic.py

+1-9
Original file line numberDiff line numberDiff line change
@@ -117,17 +117,9 @@ def arange(start, stop, step):
117117

118118
@numba_funcify.register(Join)
119119
def numba_funcify_Join(op, **kwargs):
120-
view = op.view
121-
122-
if view != -1:
123-
# TODO: Where (and why) is this `Join.view` even being used? From a
124-
# quick search, the answer appears to be "nowhere", so we should
125-
# probably just remove it.
126-
raise NotImplementedError("The `view` parameter to `Join` is not supported")
127-
128120
@numba_basic.numba_njit
129121
def join(axis, *tensors):
130-
return np.concatenate(tensors, numba_basic.to_scalar(axis))
122+
return np.concatenate(tensors, axis.item())
131123

132124
return join
133125

pytensor/scan/checkpoints.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import pytensor.tensor.basic as ptb
22
from pytensor.scan.basic import scan
3-
from pytensor.tensor.basic import Join
43
from pytensor.tensor.math import ceil, eq, neq
54
from pytensor.tensor.subtensor import set_subtensor
65

@@ -127,14 +126,12 @@ def scan_checkpoints(
127126

128127
# Pad the sequences if needed
129128
if padding:
130-
# Since padding could be an empty tensor, Join returns a view of s.
131-
join = Join(view=0)
132129
for i, s in enumerate(sequences):
133130
overshoots_by = s.shape[0] % save_every_N
134131
overshoots = neq(overshoots_by, 0)
135132
n = (save_every_N - overshoots_by) * overshoots
136133
z = ptb.zeros((n, *s.shape[1:]), dtype=s.dtype)
137-
sequences[i] = join(0, s, z)
134+
sequences[i] = ptb.join(0, s, z)
138135

139136
# Establish the input variables of the outer scan
140137
o_sequences = [

0 commit comments

Comments
 (0)