@@ -474,6 +474,27 @@ def test_basic(self, xp: ModuleType):
474
474
expected = xp .asarray ([[0.0 , 1.0 , 0.0 ], [0.0 , 0.0 , 1.0 ], [1.0 , 0.0 , 0.0 ]])
475
475
xp_assert_equal (actual , expected )
476
476
477
+ def test_2d (self , xp : ModuleType ):
478
+ actual = one_hot (xp .asarray ([[2 , 1 , 0 ], [1 , 0 , 2 ]]), 3 , axis = 1 )
479
+ expected = xp .asarray (
480
+ [
481
+ [[0.0 , 0.0 , 1.0 ], [0.0 , 1.0 , 0.0 ], [1.0 , 0.0 , 0.0 ]],
482
+ [[0.0 , 1.0 , 0.0 ], [1.0 , 0.0 , 0.0 ], [0.0 , 0.0 , 1.0 ]],
483
+ ]
484
+ )
485
+ xp_assert_equal (actual , expected )
486
+
487
+ @pytest .mark .skip_xp_backend (
488
+ Backend .ARRAY_API_STRICTEST , reason = "backend doesn't support Boolean indexing"
489
+ )
490
+ def test_abstract_size (self , xp : ModuleType ):
491
+ x = xp .arange (5 )
492
+ x = x [x > 2 ]
493
+ x = xp .astype (x , xp .int64 )
494
+ actual = one_hot (x , 5 )
495
+ expected = xp .asarray ([[0.0 , 0.0 , 0.0 , 1.0 , 0.0 ], [0.0 , 0.0 , 0.0 , 0.0 , 1.0 ]])
496
+ xp_assert_equal (actual , expected )
497
+
477
498
@pytest .mark .skip_xp_backend (
478
499
Backend .TORCH_GPU , reason = "Puts Pytorch into a bad state."
479
500
)
0 commit comments