Skip to content

Commit 64bad71

Browse files
Fix JAX conversion for AdvancedIncSubtensor1
1 parent 0813ec0 commit 64bad71

File tree

2 files changed

+41
-9
lines changed

2 files changed

+41
-9
lines changed

tests/sandbox/test_jax.py

+35-7
Original file line numberDiff line numberDiff line change
@@ -425,40 +425,40 @@ def test_jax_Subtensors():
425425
# Basic indices
426426
x_tt = tt.arange(3 * 4 * 5).reshape((3, 4, 5))
427427
out_tt = x_tt[1, 2, 0]
428-
428+
assert isinstance(out_tt.owner.op, tt.subtensor.Subtensor)
429429
out_fg = theano.gof.FunctionGraph([], [out_tt])
430430
compare_jax_and_py(out_fg, [])
431431

432432
out_tt = x_tt[1:2, 1, :]
433-
433+
assert isinstance(out_tt.owner.op, tt.subtensor.Subtensor)
434434
out_fg = theano.gof.FunctionGraph([], [out_tt])
435435
compare_jax_and_py(out_fg, [])
436436

437437
# Boolean indices
438438
out_tt = x_tt[x_tt < 0]
439-
439+
assert isinstance(out_tt.owner.op, tt.subtensor.AdvancedSubtensor)
440440
out_fg = theano.gof.FunctionGraph([], [out_tt])
441441
compare_jax_and_py(out_fg, [])
442442

443443
# Advanced indexing
444444
out_tt = x_tt[[1, 2]]
445-
445+
assert isinstance(out_tt.owner.op, tt.subtensor.AdvancedSubtensor1)
446446
out_fg = theano.gof.FunctionGraph([], [out_tt])
447447
compare_jax_and_py(out_fg, [])
448448

449449
out_tt = x_tt[[1, 2], [2, 3]]
450-
450+
assert isinstance(out_tt.owner.op, tt.subtensor.AdvancedSubtensor)
451451
out_fg = theano.gof.FunctionGraph([], [out_tt])
452452
compare_jax_and_py(out_fg, [])
453453

454454
# Advanced and basic indexing
455455
out_tt = x_tt[[1, 2], :]
456-
456+
assert isinstance(out_tt.owner.op, tt.subtensor.AdvancedSubtensor1)
457457
out_fg = theano.gof.FunctionGraph([], [out_tt])
458458
compare_jax_and_py(out_fg, [])
459459

460460
out_tt = x_tt[[1, 2], :, [3, 4]]
461-
461+
assert isinstance(out_tt.owner.op, tt.subtensor.AdvancedSubtensor)
462462
out_fg = theano.gof.FunctionGraph([], [out_tt])
463463
compare_jax_and_py(out_fg, [])
464464

@@ -470,64 +470,92 @@ def test_jax_IncSubtensor():
470470
# "Set" basic indices
471471
st_tt = tt.as_tensor_variable(np.array(-10.0, dtype=theano.config.floatX))
472472
out_tt = tt.set_subtensor(x_tt[1, 2, 3], st_tt)
473+
assert isinstance(out_tt.owner.op, tt.subtensor.IncSubtensor)
473474
out_fg = theano.gof.FunctionGraph([], [out_tt])
474475
compare_jax_and_py(out_fg, [])
475476

476477
st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(theano.config.floatX))
477478
out_tt = tt.set_subtensor(x_tt[:2, 0, 0], st_tt)
479+
assert isinstance(out_tt.owner.op, tt.subtensor.IncSubtensor)
478480
out_fg = theano.gof.FunctionGraph([], [out_tt])
479481
compare_jax_and_py(out_fg, [])
480482

481483
out_tt = tt.set_subtensor(x_tt[0, 1:3, 0], st_tt)
484+
assert isinstance(out_tt.owner.op, tt.subtensor.IncSubtensor)
482485
out_fg = theano.gof.FunctionGraph([], [out_tt])
483486
compare_jax_and_py(out_fg, [])
484487

485488
# "Set" advanced indices
489+
st_tt = tt.as_tensor_variable(
490+
np.random.uniform(-1, 1, size=(2, 4, 5)).astype(theano.config.floatX)
491+
)
492+
out_tt = tt.set_subtensor(x_tt[np.r_[0, 2]], st_tt)
493+
assert isinstance(out_tt.owner.op, tt.subtensor.AdvancedIncSubtensor1)
494+
out_fg = theano.gof.FunctionGraph([], [out_tt])
495+
compare_jax_and_py(out_fg, [])
496+
486497
st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(theano.config.floatX))
487498
out_tt = tt.set_subtensor(x_tt[[0, 2], 0, 0], st_tt)
499+
assert isinstance(out_tt.owner.op, tt.subtensor.AdvancedIncSubtensor)
488500
out_fg = theano.gof.FunctionGraph([], [out_tt])
489501
compare_jax_and_py(out_fg, [])
490502

491503
st_tt = tt.as_tensor_variable(x_np[[0, 2], 0, :3])
492504
out_tt = tt.set_subtensor(x_tt[[0, 2], 0, :3], st_tt)
505+
assert isinstance(out_tt.owner.op, tt.subtensor.AdvancedIncSubtensor)
493506
out_fg = theano.gof.FunctionGraph([], [out_tt])
494507
compare_jax_and_py(out_fg, [])
495508

496509
# "Set" boolean indices
497510
mask_tt = tt.as_tensor_variable(x_np) > 0
498511
out_tt = tt.set_subtensor(x_tt[mask_tt], 0.0)
512+
assert isinstance(out_tt.owner.op, tt.subtensor.AdvancedIncSubtensor)
499513
out_fg = theano.gof.FunctionGraph([], [out_tt])
500514
compare_jax_and_py(out_fg, [])
501515

502516
# "Increment" basic indices
503517
st_tt = tt.as_tensor_variable(np.array(-10.0, dtype=theano.config.floatX))
504518
out_tt = tt.inc_subtensor(x_tt[1, 2, 3], st_tt)
519+
assert isinstance(out_tt.owner.op, tt.subtensor.IncSubtensor)
505520
out_fg = theano.gof.FunctionGraph([], [out_tt])
506521
compare_jax_and_py(out_fg, [])
507522

508523
st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(theano.config.floatX))
509524
out_tt = tt.inc_subtensor(x_tt[:2, 0, 0], st_tt)
525+
assert isinstance(out_tt.owner.op, tt.subtensor.IncSubtensor)
510526
out_fg = theano.gof.FunctionGraph([], [out_tt])
511527
compare_jax_and_py(out_fg, [])
512528

513529
out_tt = tt.set_subtensor(x_tt[0, 1:3, 0], st_tt)
530+
assert isinstance(out_tt.owner.op, tt.subtensor.IncSubtensor)
514531
out_fg = theano.gof.FunctionGraph([], [out_tt])
515532
compare_jax_and_py(out_fg, [])
516533

517534
# "Increment" advanced indices
535+
st_tt = tt.as_tensor_variable(
536+
np.random.uniform(-1, 1, size=(2, 4, 5)).astype(theano.config.floatX)
537+
)
538+
out_tt = tt.inc_subtensor(x_tt[np.r_[0, 2]], st_tt)
539+
assert isinstance(out_tt.owner.op, tt.subtensor.AdvancedIncSubtensor1)
540+
out_fg = theano.gof.FunctionGraph([], [out_tt])
541+
compare_jax_and_py(out_fg, [])
542+
518543
st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(theano.config.floatX))
519544
out_tt = tt.inc_subtensor(x_tt[[0, 2], 0, 0], st_tt)
545+
assert isinstance(out_tt.owner.op, tt.subtensor.AdvancedIncSubtensor)
520546
out_fg = theano.gof.FunctionGraph([], [out_tt])
521547
compare_jax_and_py(out_fg, [])
522548

523549
st_tt = tt.as_tensor_variable(x_np[[0, 2], 0, :3])
524550
out_tt = tt.inc_subtensor(x_tt[[0, 2], 0, :3], st_tt)
551+
assert isinstance(out_tt.owner.op, tt.subtensor.AdvancedIncSubtensor)
525552
out_fg = theano.gof.FunctionGraph([], [out_tt])
526553
compare_jax_and_py(out_fg, [])
527554

528555
# "Increment" boolean indices
529556
mask_tt = tt.as_tensor_variable(x_np) > 0
530557
out_tt = tt.set_subtensor(x_tt[mask_tt], 1.0)
558+
assert isinstance(out_tt.owner.op, tt.subtensor.AdvancedIncSubtensor)
531559
out_fg = theano.gof.FunctionGraph([], [out_tt])
532560
compare_jax_and_py(out_fg, [])
533561

theano/sandbox/jaxify.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -623,7 +623,7 @@ def subtensor(x, *ilists):
623623

624624
def jax_funcify_IncSubtensor(op):
625625

626-
idx_list = op.idx_list
626+
idx_list = getattr(op, "idx_list", None)
627627

628628
if getattr(op, "set_instead_of_inc", False):
629629
jax_fn = jax.ops.index_update
@@ -632,7 +632,11 @@ def jax_funcify_IncSubtensor(op):
632632

633633
def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):
634634
_ilist = list(ilist)
635-
cdata = tuple(convert_indices(_ilist, idx) for idx in idx_list)
635+
cdata = (
636+
tuple(convert_indices(_ilist, idx) for idx in idx_list)
637+
if idx_list
638+
else _ilist
639+
)
636640
if len(cdata) == 1:
637641
cdata = cdata[0]
638642

0 commit comments

Comments
 (0)