From 89b0b912309d27b33315124b170fc7c3baada59c Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 21 Feb 2025 19:21:40 +0100 Subject: [PATCH] .wip --- pytensor/link/jax/dispatch/scan.py | 141 ++++++++++++++++++++--------- tests/link/jax/test_scan.py | 9 +- 2 files changed, 101 insertions(+), 49 deletions(-) diff --git a/pytensor/link/jax/dispatch/scan.py b/pytensor/link/jax/dispatch/scan.py index d98328f0cf..2728f53490 100644 --- a/pytensor/link/jax/dispatch/scan.py +++ b/pytensor/link/jax/dispatch/scan.py @@ -13,11 +13,6 @@ def jax_funcify_Scan(op: Scan, **kwargs): if info.as_while: raise NotImplementedError("While Scan cannot yet be converted to JAX") - if info.n_mit_mot: - raise NotImplementedError( - "Scan with MIT-MOT (gradients of scan) cannot yet be converted to JAX" - ) - # Optimize inner graph (exclude any defalut rewrites that are incompatible with JAX mode) rewriter = op.mode_instance.excluding(*JAX._optimizer.exclude).optimizer rewriter(op.fgraph) @@ -29,18 +24,30 @@ def scan(*outer_inputs): n_steps = outer_inputs[0] # JAX `length` seqs = op.outer_seqs(outer_inputs) # JAX `xs` - mit_sot_init = [] - for tap, seq in zip( + # MIT-MOT and MIT-SOT are provided from outside as a tape long enough to store the initial values and intermediate outputs + # To bootstrap the inner function we need to slice the initial values + mit_mot_inits = [] + for taps, seq in zip( + op.info.mit_mot_in_slices, op.outer_mitmot(outer_inputs), strict=True + ): + # mit-mot taps are non-negative + init_slice = seq[: max(taps) + 1] + mit_mot_inits.append(init_slice) + + mit_sot_inits = [] + for taps, seq in zip( op.info.mit_sot_in_slices, op.outer_mitsot(outer_inputs), strict=True ): - init_slice = seq[: abs(min(tap))] - mit_sot_init.append(init_slice) + # mit-sot taps are negative + init_slice = seq[: abs(min(taps))] + mit_sot_inits.append(init_slice) - sit_sot_init = [seq[0] for seq in op.outer_sitsot(outer_inputs)] + sit_sot_inits = [seq[0] for seq in op.outer_sitsot(outer_inputs)] init_carry = ( - mit_sot_init, - sit_sot_init, + mit_mot_inits, + mit_sot_inits, + sit_sot_inits, op.outer_shared(outer_inputs), op.outer_non_seqs(outer_inputs), ) # JAX `init` @@ -48,31 +55,43 @@ def scan(*outer_inputs): def jax_args_to_inner_func_args(carry, x): """Convert JAX scan arguments into format expected by scan_inner_func. - scan(carry, x) -> scan_inner_func(seqs, mit_sot, sit_sot, shared, non_seqs) + scan(carry, x) -> scan_inner_func(seqs, mit_mot, mit_sot, sit_sot, shared, non_seqs) """ # `carry` contains all inner taps, shared terms, and non_seqs ( - inner_mit_sot, - inner_sit_sot, - inner_shared, + inner_mit_mots, + inner_mit_sots, + inner_sit_sots, + inner_shareds, inner_non_seqs, ) = carry # `x` contains the inner sequences inner_seqs = x - mit_sot_flatten = [] - for array, index in zip( - inner_mit_sot, op.info.mit_sot_in_slices, strict=True + # MIT-MOT and MIT-SOT are provided as unified tensors and should be split + # into distinct entries for the inner function + split_mit_mots = [] + for taps, seq in zip( + op.info.mit_mot_in_slices, inner_mit_mots, strict=True + ): + for tap in taps: + split_mit_mots.append(seq[tap]) + + split_mit_sots = [] + for taps, seq in zip( + op.info.mit_sot_in_slices, inner_mit_sots, strict=True ): - mit_sot_flatten.extend(array[jnp.array(index)]) + for tap in taps: + split_mit_sots.append(seq[tap]) inner_scan_inputs = [ *inner_seqs, - *mit_sot_flatten, - *inner_sit_sot, - *inner_shared, + *split_mit_mots, # TODO: Confirm oreding + *split_mit_sots, + *inner_sit_sots, + *inner_shareds, *inner_non_seqs, ] @@ -84,44 +103,71 @@ def inner_func_outs_to_jax_outs( ): """Convert inner_scan_func outputs into format expected by JAX scan. - old_carry + (mit_sot_outs, sit_sot_outs, nit_sot_outs, shared_outs) -> (new_carry, ys) + old_carry + (mit_mot_outs, mit_sot_outs, sit_sot_outs, nit_sot_outs, shared_outs) -> (new_carry, ys) """ ( - inner_mit_sot, - inner_sit_sot, - inner_shared, + inner_mit_mots, + inner_mit_sots, + inner_sit_sots, + inner_shareds, inner_non_seqs, ) = old_carry + inner_mit_mot_outs = op.inner_mitmot_outs(inner_scan_outs) inner_mit_sot_outs = op.inner_mitsot_outs(inner_scan_outs) inner_sit_sot_outs = op.inner_sitsot_outs(inner_scan_outs) inner_nit_sot_outs = op.inner_nitsot_outs(inner_scan_outs) inner_shared_outs = op.inner_shared_outs(inner_scan_outs) - # Replace the oldest mit_sot tap by the newest value - inner_mit_sot_new = [ - jnp.concatenate([old_mit_sot[1:], new_val[None, ...]], axis=0) - for old_mit_sot, new_val in zip( - inner_mit_sot, inner_mit_sot_outs, strict=True + # Group split mit_mot_outs into the respective groups + start = 0 + grouped_inner_mit_mot_outs = [] + for mit_mot_out_slice in op.info.mit_mot_out_slices: + end = start + len(mit_mot_out_slice) + elements = inner_mit_mot_outs[start:end] + group = jnp.concatenate([e[None] for e in elements], axis=0) + grouped_inner_mit_mot_outs.append(group) + start = end + + # Replace the oldest mit-mot taps (last entries) and prepend the newest values + new_inner_mit_mots = [] + for old_mit_mot, new_outs in zip( + inner_mit_mots, grouped_inner_mit_mot_outs, strict=True + ): + n_outs = len(new_outs) + inner_mit_mot_new = jnp.concatenate( + [old_mit_mot[n_outs:], group], axis=0 ) - ] + new_inner_mit_mots.append(inner_mit_mot_new) + + # Drop the oldest mit-sot tap (first entry) and append the newest value at end + new_inner_mit_sots = [] + for old_mit_sot, new_out in zip( + inner_mit_sots, inner_mit_sot_outs, strict=True + ): + inner_mit_sot_new = jnp.concatenate( + [old_mit_sot[1:], new_out[None, ...]], axis=0 + ) + new_inner_mit_mots.append(inner_mit_sot_new) # Nothing needs to be done with sit_sot - inner_sit_sot_new = inner_sit_sot_outs + new_inner_sit_sots = inner_sit_sot_outs - inner_shared_new = inner_shared + new_inner_shareds = inner_shareds # Replace old shared inputs by new shared outputs - inner_shared_new[: len(inner_shared_outs)] = inner_shared_outs + new_inner_shareds[: len(inner_shared_outs)] = inner_shared_outs new_carry = ( - inner_mit_sot_new, - inner_sit_sot_new, - inner_shared_new, + new_inner_mit_mots, + new_inner_mit_sots, + new_inner_sit_sots, + new_inner_shareds, inner_non_seqs, ) # Shared variables and non_seqs are not traced traced_outs = [ + *grouped_inner_mit_mot_outs, *inner_mit_sot_outs, *inner_sit_sot_outs, *inner_nit_sot_outs, @@ -148,9 +194,15 @@ def get_partial_traces(traces): 2. Slice final traces if Scan was instructed to only keep a portion """ - init_states = mit_sot_init + sit_sot_init + [None] * op.info.n_nit_sot + init_states = ( + mit_mot_inits + + mit_sot_inits + + sit_sot_inits + + [None] * op.info.n_nit_sot + ) buffers = ( - op.outer_mitsot(outer_inputs) + op.outer_mitmot(outer_inputs) + + op.outer_mitsot(outer_inputs) + op.outer_sitsot(outer_inputs) + op.outer_nitsot(outer_inputs) ) @@ -159,11 +211,10 @@ def get_partial_traces(traces): init_states, traces, buffers, strict=True ): if init_state is not None: - # MIT-SOT and SIT-SOT: The final output should be as long as the input buffer + # MIT-MOT, MIT-SOT and SIT-SOT: The final output should be as long as the input buffer trace = jnp.atleast_1d(trace) - init_state = jnp.expand_dims( - init_state, range(trace.ndim - init_state.ndim) - ) + init_state = jnp.expand_dims(init_state, 1) + # TODO: delete this, shouldn't be needed? full_trace = jnp.concatenate([init_state, trace], axis=0) buffer_size = buffer.shape[0] else: diff --git a/tests/link/jax/test_scan.py b/tests/link/jax/test_scan.py index 4ee95ab527..72eeace382 100644 --- a/tests/link/jax/test_scan.py +++ b/tests/link/jax/test_scan.py @@ -98,16 +98,17 @@ def test_scan_nit_sot(view): assert len(scan_nodes) == 1 -@pytest.mark.xfail(raises=NotImplementedError) def test_scan_mit_mot(): - xs = pt.vector("xs", shape=(10,)) + xs = pt.tensor("xs", shape=(2, 2)) ys, _ = scan( lambda xtm2, xtm1: (xtm2 + xtm1), outputs_info=[{"initial": xs, "taps": [-2, -1]}], - n_steps=10, + n_steps=4, ) grads_wrt_xs = pt.grad(ys.sum(), wrt=xs) - compare_jax_and_py([xs], [grads_wrt_xs], [np.arange(10)]) + f = function([xs], grads_wrt_xs, mode="JAX") + f(np.arange(4).reshape((2, 2))) + # compare_jax_and_py([xs], [grads_wrt_xs], [np.arange(2)]) def test_scan_update():