@@ -425,40 +425,40 @@ def test_jax_Subtensors():
425
425
# Basic indices
426
426
x_tt = tt .arange (3 * 4 * 5 ).reshape ((3 , 4 , 5 ))
427
427
out_tt = x_tt [1 , 2 , 0 ]
428
-
428
+ assert isinstance ( out_tt . owner . op , tt . subtensor . Subtensor )
429
429
out_fg = theano .gof .FunctionGraph ([], [out_tt ])
430
430
compare_jax_and_py (out_fg , [])
431
431
432
432
out_tt = x_tt [1 :2 , 1 , :]
433
-
433
+ assert isinstance ( out_tt . owner . op , tt . subtensor . Subtensor )
434
434
out_fg = theano .gof .FunctionGraph ([], [out_tt ])
435
435
compare_jax_and_py (out_fg , [])
436
436
437
437
# Boolean indices
438
438
out_tt = x_tt [x_tt < 0 ]
439
-
439
+ assert isinstance ( out_tt . owner . op , tt . subtensor . AdvancedSubtensor )
440
440
out_fg = theano .gof .FunctionGraph ([], [out_tt ])
441
441
compare_jax_and_py (out_fg , [])
442
442
443
443
# Advanced indexing
444
444
out_tt = x_tt [[1 , 2 ]]
445
-
445
+ assert isinstance ( out_tt . owner . op , tt . subtensor . AdvancedSubtensor1 )
446
446
out_fg = theano .gof .FunctionGraph ([], [out_tt ])
447
447
compare_jax_and_py (out_fg , [])
448
448
449
449
out_tt = x_tt [[1 , 2 ], [2 , 3 ]]
450
-
450
+ assert isinstance ( out_tt . owner . op , tt . subtensor . AdvancedSubtensor )
451
451
out_fg = theano .gof .FunctionGraph ([], [out_tt ])
452
452
compare_jax_and_py (out_fg , [])
453
453
454
454
# Advanced and basic indexing
455
455
out_tt = x_tt [[1 , 2 ], :]
456
-
456
+ assert isinstance ( out_tt . owner . op , tt . subtensor . AdvancedSubtensor1 )
457
457
out_fg = theano .gof .FunctionGraph ([], [out_tt ])
458
458
compare_jax_and_py (out_fg , [])
459
459
460
460
out_tt = x_tt [[1 , 2 ], :, [3 , 4 ]]
461
-
461
+ assert isinstance ( out_tt . owner . op , tt . subtensor . AdvancedSubtensor )
462
462
out_fg = theano .gof .FunctionGraph ([], [out_tt ])
463
463
compare_jax_and_py (out_fg , [])
464
464
@@ -470,64 +470,92 @@ def test_jax_IncSubtensor():
470
470
# "Set" basic indices
471
471
st_tt = tt .as_tensor_variable (np .array (- 10.0 , dtype = theano .config .floatX ))
472
472
out_tt = tt .set_subtensor (x_tt [1 , 2 , 3 ], st_tt )
473
+ assert isinstance (out_tt .owner .op , tt .subtensor .IncSubtensor )
473
474
out_fg = theano .gof .FunctionGraph ([], [out_tt ])
474
475
compare_jax_and_py (out_fg , [])
475
476
476
477
st_tt = tt .as_tensor_variable (np .r_ [- 1.0 , 0.0 ].astype (theano .config .floatX ))
477
478
out_tt = tt .set_subtensor (x_tt [:2 , 0 , 0 ], st_tt )
479
+ assert isinstance (out_tt .owner .op , tt .subtensor .IncSubtensor )
478
480
out_fg = theano .gof .FunctionGraph ([], [out_tt ])
479
481
compare_jax_and_py (out_fg , [])
480
482
481
483
out_tt = tt .set_subtensor (x_tt [0 , 1 :3 , 0 ], st_tt )
484
+ assert isinstance (out_tt .owner .op , tt .subtensor .IncSubtensor )
482
485
out_fg = theano .gof .FunctionGraph ([], [out_tt ])
483
486
compare_jax_and_py (out_fg , [])
484
487
485
488
# "Set" advanced indices
489
+ st_tt = tt .as_tensor_variable (
490
+ np .random .uniform (- 1 , 1 , size = (2 , 4 , 5 )).astype (theano .config .floatX )
491
+ )
492
+ out_tt = tt .set_subtensor (x_tt [np .r_ [0 , 2 ]], st_tt )
493
+ assert isinstance (out_tt .owner .op , tt .subtensor .AdvancedIncSubtensor1 )
494
+ out_fg = theano .gof .FunctionGraph ([], [out_tt ])
495
+ compare_jax_and_py (out_fg , [])
496
+
486
497
st_tt = tt .as_tensor_variable (np .r_ [- 1.0 , 0.0 ].astype (theano .config .floatX ))
487
498
out_tt = tt .set_subtensor (x_tt [[0 , 2 ], 0 , 0 ], st_tt )
499
+ assert isinstance (out_tt .owner .op , tt .subtensor .AdvancedIncSubtensor )
488
500
out_fg = theano .gof .FunctionGraph ([], [out_tt ])
489
501
compare_jax_and_py (out_fg , [])
490
502
491
503
st_tt = tt .as_tensor_variable (x_np [[0 , 2 ], 0 , :3 ])
492
504
out_tt = tt .set_subtensor (x_tt [[0 , 2 ], 0 , :3 ], st_tt )
505
+ assert isinstance (out_tt .owner .op , tt .subtensor .AdvancedIncSubtensor )
493
506
out_fg = theano .gof .FunctionGraph ([], [out_tt ])
494
507
compare_jax_and_py (out_fg , [])
495
508
496
509
# "Set" boolean indices
497
510
mask_tt = tt .as_tensor_variable (x_np ) > 0
498
511
out_tt = tt .set_subtensor (x_tt [mask_tt ], 0.0 )
512
+ assert isinstance (out_tt .owner .op , tt .subtensor .AdvancedIncSubtensor )
499
513
out_fg = theano .gof .FunctionGraph ([], [out_tt ])
500
514
compare_jax_and_py (out_fg , [])
501
515
502
516
# "Increment" basic indices
503
517
st_tt = tt .as_tensor_variable (np .array (- 10.0 , dtype = theano .config .floatX ))
504
518
out_tt = tt .inc_subtensor (x_tt [1 , 2 , 3 ], st_tt )
519
+ assert isinstance (out_tt .owner .op , tt .subtensor .IncSubtensor )
505
520
out_fg = theano .gof .FunctionGraph ([], [out_tt ])
506
521
compare_jax_and_py (out_fg , [])
507
522
508
523
st_tt = tt .as_tensor_variable (np .r_ [- 1.0 , 0.0 ].astype (theano .config .floatX ))
509
524
out_tt = tt .inc_subtensor (x_tt [:2 , 0 , 0 ], st_tt )
525
+ assert isinstance (out_tt .owner .op , tt .subtensor .IncSubtensor )
510
526
out_fg = theano .gof .FunctionGraph ([], [out_tt ])
511
527
compare_jax_and_py (out_fg , [])
512
528
513
529
out_tt = tt .set_subtensor (x_tt [0 , 1 :3 , 0 ], st_tt )
530
+ assert isinstance (out_tt .owner .op , tt .subtensor .IncSubtensor )
514
531
out_fg = theano .gof .FunctionGraph ([], [out_tt ])
515
532
compare_jax_and_py (out_fg , [])
516
533
517
534
# "Increment" advanced indices
535
+ st_tt = tt .as_tensor_variable (
536
+ np .random .uniform (- 1 , 1 , size = (2 , 4 , 5 )).astype (theano .config .floatX )
537
+ )
538
+ out_tt = tt .inc_subtensor (x_tt [np .r_ [0 , 2 ]], st_tt )
539
+ assert isinstance (out_tt .owner .op , tt .subtensor .AdvancedIncSubtensor1 )
540
+ out_fg = theano .gof .FunctionGraph ([], [out_tt ])
541
+ compare_jax_and_py (out_fg , [])
542
+
518
543
st_tt = tt .as_tensor_variable (np .r_ [- 1.0 , 0.0 ].astype (theano .config .floatX ))
519
544
out_tt = tt .inc_subtensor (x_tt [[0 , 2 ], 0 , 0 ], st_tt )
545
+ assert isinstance (out_tt .owner .op , tt .subtensor .AdvancedIncSubtensor )
520
546
out_fg = theano .gof .FunctionGraph ([], [out_tt ])
521
547
compare_jax_and_py (out_fg , [])
522
548
523
549
st_tt = tt .as_tensor_variable (x_np [[0 , 2 ], 0 , :3 ])
524
550
out_tt = tt .inc_subtensor (x_tt [[0 , 2 ], 0 , :3 ], st_tt )
551
+ assert isinstance (out_tt .owner .op , tt .subtensor .AdvancedIncSubtensor )
525
552
out_fg = theano .gof .FunctionGraph ([], [out_tt ])
526
553
compare_jax_and_py (out_fg , [])
527
554
528
555
# "Increment" boolean indices
529
556
mask_tt = tt .as_tensor_variable (x_np ) > 0
530
557
out_tt = tt .set_subtensor (x_tt [mask_tt ], 1.0 )
558
+ assert isinstance (out_tt .owner .op , tt .subtensor .AdvancedIncSubtensor )
531
559
out_fg = theano .gof .FunctionGraph ([], [out_tt ])
532
560
compare_jax_and_py (out_fg , [])
533
561
0 commit comments