Skip to content

Commit d8fb0e4

Browse files
committed
test abstract size
1 parent ed58d68 commit d8fb0e4

File tree

1 file changed

+21
-0
lines changed

1 file changed

+21
-0
lines changed

tests/test_funcs.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,27 @@ def test_basic(self, xp: ModuleType):
474474
expected = xp.asarray([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
475475
xp_assert_equal(actual, expected)
476476

477+
def test_2d(self, xp: ModuleType):
478+
actual = one_hot(xp.asarray([[2, 1, 0], [1, 0, 2]]), 3, axis=1)
479+
expected = xp.asarray(
480+
[
481+
[[0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]],
482+
[[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]],
483+
]
484+
)
485+
xp_assert_equal(actual, expected)
486+
487+
@pytest.mark.skip_xp_backend(
488+
Backend.ARRAY_API_STRICTEST, reason="backend doesn't support Boolean indexing"
489+
)
490+
def test_abstract_size(self, xp: ModuleType):
491+
x = xp.arange(5)
492+
x = x[x > 2]
493+
x = xp.astype(x, xp.int64)
494+
actual = one_hot(x, 5)
495+
expected = xp.asarray([[0.0, 0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0, 1.0]])
496+
xp_assert_equal(actual, expected)
497+
477498
@pytest.mark.skip_xp_backend(
478499
Backend.TORCH_GPU, reason="Puts Pytorch into a bad state."
479500
)

0 commit comments

Comments
 (0)