Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement MIT-MOT in JAX dispatch of Scan #1232

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 96 additions & 45 deletions pytensor/link/jax/dispatch/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,6 @@ def jax_funcify_Scan(op: Scan, **kwargs):
if info.as_while:
raise NotImplementedError("While Scan cannot yet be converted to JAX")

if info.n_mit_mot:
raise NotImplementedError(
"Scan with MIT-MOT (gradients of scan) cannot yet be converted to JAX"
)

# Optimize inner graph (exclude any default rewrites that are incompatible with JAX mode)
rewriter = op.mode_instance.excluding(*JAX._optimizer.exclude).optimizer
rewriter(op.fgraph)
Expand All @@ -29,50 +24,74 @@ def scan(*outer_inputs):
n_steps = outer_inputs[0] # JAX `length`
seqs = op.outer_seqs(outer_inputs) # JAX `xs`

mit_sot_init = []
for tap, seq in zip(
# MIT-MOT and MIT-SOT are provided from outside as a tape long enough to store the initial values and intermediate outputs
# To bootstrap the inner function we need to slice the initial values
mit_mot_inits = []
for taps, seq in zip(
op.info.mit_mot_in_slices, op.outer_mitmot(outer_inputs), strict=True
):
# mit-mot taps are non-negative
init_slice = seq[: max(taps) + 1]
mit_mot_inits.append(init_slice)

mit_sot_inits = []
for taps, seq in zip(
op.info.mit_sot_in_slices, op.outer_mitsot(outer_inputs), strict=True
):
init_slice = seq[: abs(min(tap))]
mit_sot_init.append(init_slice)
# mit-sot taps are negative
init_slice = seq[: abs(min(taps))]
mit_sot_inits.append(init_slice)

sit_sot_init = [seq[0] for seq in op.outer_sitsot(outer_inputs)]
sit_sot_inits = [seq[0] for seq in op.outer_sitsot(outer_inputs)]

init_carry = (
mit_sot_init,
sit_sot_init,
mit_mot_inits,
mit_sot_inits,
sit_sot_inits,
op.outer_shared(outer_inputs),
op.outer_non_seqs(outer_inputs),
) # JAX `init`

def jax_args_to_inner_func_args(carry, x):
"""Convert JAX scan arguments into format expected by scan_inner_func.

scan(carry, x) -> scan_inner_func(seqs, mit_sot, sit_sot, shared, non_seqs)
scan(carry, x) -> scan_inner_func(seqs, mit_mot, mit_sot, sit_sot, shared, non_seqs)
"""

# `carry` contains all inner taps, shared terms, and non_seqs
(
inner_mit_sot,
inner_sit_sot,
inner_shared,
inner_mit_mots,
inner_mit_sots,
inner_sit_sots,
inner_shareds,
inner_non_seqs,
) = carry

# `x` contains the inner sequences
inner_seqs = x

mit_sot_flatten = []
for array, index in zip(
inner_mit_sot, op.info.mit_sot_in_slices, strict=True
# MIT-MOT and MIT-SOT are provided as unified tensors and should be split
# into distinct entries for the inner function
split_mit_mots = []
for taps, seq in zip(
op.info.mit_mot_in_slices, inner_mit_mots, strict=True
):
for tap in taps:
split_mit_mots.append(seq[tap])

split_mit_sots = []
for taps, seq in zip(
op.info.mit_sot_in_slices, inner_mit_sots, strict=True
):
mit_sot_flatten.extend(array[jnp.array(index)])
for tap in taps:
split_mit_sots.append(seq[tap])

inner_scan_inputs = [
*inner_seqs,
*mit_sot_flatten,
*inner_sit_sot,
*inner_shared,
*split_mit_mots, # TODO: Confirm ordering
*split_mit_sots,
*inner_sit_sots,
*inner_shareds,
*inner_non_seqs,
]

Expand All @@ -84,44 +103,71 @@ def inner_func_outs_to_jax_outs(
):
"""Convert inner_scan_func outputs into format expected by JAX scan.

old_carry + (mit_sot_outs, sit_sot_outs, nit_sot_outs, shared_outs) -> (new_carry, ys)
old_carry + (mit_mot_outs, mit_sot_outs, sit_sot_outs, nit_sot_outs, shared_outs) -> (new_carry, ys)
"""
(
inner_mit_sot,
inner_sit_sot,
inner_shared,
inner_mit_mots,
inner_mit_sots,
inner_sit_sots,
inner_shareds,
inner_non_seqs,
) = old_carry

inner_mit_mot_outs = op.inner_mitmot_outs(inner_scan_outs)
inner_mit_sot_outs = op.inner_mitsot_outs(inner_scan_outs)
inner_sit_sot_outs = op.inner_sitsot_outs(inner_scan_outs)
inner_nit_sot_outs = op.inner_nitsot_outs(inner_scan_outs)
inner_shared_outs = op.inner_shared_outs(inner_scan_outs)

# Replace the oldest mit_sot tap by the newest value
inner_mit_sot_new = [
jnp.concatenate([old_mit_sot[1:], new_val[None, ...]], axis=0)
for old_mit_sot, new_val in zip(
inner_mit_sot, inner_mit_sot_outs, strict=True
# Group split mit_mot_outs into the respective groups
start = 0
grouped_inner_mit_mot_outs = []
for mit_mot_out_slice in op.info.mit_mot_out_slices:
end = start + len(mit_mot_out_slice)
elements = inner_mit_mot_outs[start:end]
group = jnp.concatenate([e[None] for e in elements], axis=0)
grouped_inner_mit_mot_outs.append(group)
start = end

# Replace the oldest mit-mot taps (last entries) and prepend the newest values
new_inner_mit_mots = []
for old_mit_mot, new_outs in zip(
inner_mit_mots, grouped_inner_mit_mot_outs, strict=True
):
# Number of freshly computed taps for this particular mit-mot state
n_outs = len(new_outs)
# Concatenate with this iteration's `new_outs`; the name `group` is a leftover
# from the grouping loop above and always refers to the *last* mit-mot group,
# which would be wrong as soon as there is more than one mit-mot state.
inner_mit_mot_new = jnp.concatenate(
    [old_mit_mot[n_outs:], new_outs], axis=0
)
]
new_inner_mit_mots.append(inner_mit_mot_new)

# Drop the oldest mit-sot tap (first entry) and append the newest value at end
new_inner_mit_sots = []
for old_mit_sot, new_out in zip(
    inner_mit_sots, inner_mit_sot_outs, strict=True
):
    # `new_out[None, ...]` adds a leading axis so it can be stacked after the
    # remaining taps along axis 0
    inner_mit_sot_new = jnp.concatenate(
        [old_mit_sot[1:], new_out[None, ...]], axis=0
    )
    # Collect into the mit-sot list: appending to `new_inner_mit_mots` here
    # would leave the mit-sot carry empty and corrupt the mit-mot carry.
    new_inner_mit_sots.append(inner_mit_sot_new)

# Nothing needs to be done with sit_sot
inner_sit_sot_new = inner_sit_sot_outs
new_inner_sit_sots = inner_sit_sot_outs

inner_shared_new = inner_shared
new_inner_shareds = inner_shareds
# Replace old shared inputs by new shared outputs
inner_shared_new[: len(inner_shared_outs)] = inner_shared_outs
new_inner_shareds[: len(inner_shared_outs)] = inner_shared_outs

new_carry = (
inner_mit_sot_new,
inner_sit_sot_new,
inner_shared_new,
new_inner_mit_mots,
new_inner_mit_sots,
new_inner_sit_sots,
new_inner_shareds,
inner_non_seqs,
)

# Shared variables and non_seqs are not traced
traced_outs = [
*grouped_inner_mit_mot_outs,
*inner_mit_sot_outs,
*inner_sit_sot_outs,
*inner_nit_sot_outs,
Expand All @@ -148,9 +194,15 @@ def get_partial_traces(traces):
2. Slice final traces if Scan was instructed to only keep a portion
"""

init_states = mit_sot_init + sit_sot_init + [None] * op.info.n_nit_sot
init_states = (
mit_mot_inits
+ mit_sot_inits
+ sit_sot_inits
+ [None] * op.info.n_nit_sot
)
buffers = (
op.outer_mitsot(outer_inputs)
op.outer_mitmot(outer_inputs)
+ op.outer_mitsot(outer_inputs)
+ op.outer_sitsot(outer_inputs)
+ op.outer_nitsot(outer_inputs)
)
Expand All @@ -159,11 +211,10 @@ def get_partial_traces(traces):
init_states, traces, buffers, strict=True
):
if init_state is not None:
# MIT-SOT and SIT-SOT: The final output should be as long as the input buffer
# MIT-MOT, MIT-SOT and SIT-SOT: The final output should be as long as the input buffer
trace = jnp.atleast_1d(trace)
init_state = jnp.expand_dims(
init_state, range(trace.ndim - init_state.ndim)
)
init_state = jnp.expand_dims(init_state, 1)
# TODO: delete this, shouldn't be needed?
full_trace = jnp.concatenate([init_state, trace], axis=0)
buffer_size = buffer.shape[0]
else:
Expand Down
9 changes: 5 additions & 4 deletions tests/link/jax/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,16 +98,17 @@ def test_scan_nit_sot(view):
assert len(scan_nodes) == 1


@pytest.mark.xfail(raises=NotImplementedError)
def test_scan_mit_mot():
xs = pt.vector("xs", shape=(10,))
xs = pt.tensor("xs", shape=(2, 2))
ys, _ = scan(
lambda xtm2, xtm1: (xtm2 + xtm1),
outputs_info=[{"initial": xs, "taps": [-2, -1]}],
n_steps=10,
n_steps=4,
)
grads_wrt_xs = pt.grad(ys.sum(), wrt=xs)
compare_jax_and_py([xs], [grads_wrt_xs], [np.arange(10)])
f = function([xs], grads_wrt_xs, mode="JAX")
f(np.arange(4).reshape((2, 2)))
# compare_jax_and_py([xs], [grads_wrt_xs], [np.arange(2)])


def test_scan_update():
Expand Down
Loading