Skip to content

Commit 62e453d

Browse files
committed
Added test axis=None for Repeat in PyTorch
1 parent 8dbe5ca commit 62e453d

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

Diff for: tests/link/pytorch/test_extra_ops.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -43,18 +43,29 @@ def test_pytorch_CumOp(axis, dtype):
4343
compare_pytorch_and_py(fgraph, [test_value])
4444

4545

46-
@pytest.mark.parametrize("axis", [0, 1])
47-
def test_pytorch_Repeat(axis):
46+
@pytest.mark.parametrize(
47+
"axis, repeats",
48+
[
49+
(0, (1, 2, 3)),
50+
(1, (3, 3)),
51+
pytest.param(
52+
None,
53+
3,
54+
marks=pytest.mark.xfail(reason="Reshape not implemented"),
55+
),
56+
],
57+
)
58+
def test_pytorch_Repeat(axis, repeats):
4859
a = pt.matrix("a", dtype="float64")
4960

5061
test_value = np.arange(6, dtype="float64").reshape((3, 2))
5162

52-
out = pt.repeat(a, (1, 2, 3) if axis == 0 else (3, 3), axis=axis)
63+
out = pt.repeat(a, repeats, axis=axis)
5364
fgraph = FunctionGraph([a], [out])
5465
compare_pytorch_and_py(fgraph, [test_value])
5566

5667

57-
@pytest.mark.parametrize("axis", [0, 1])
68+
@pytest.mark.parametrize("axis", [None, 0, 1])
5869
def test_pytorch_Unique_axis(axis):
5970
a = pt.matrix("a", dtype="float64")
6071

0 commit comments

Comments
 (0)