Skip to content

Commit 330a7d2

Browse files
committed
Parametrized tests for Repeat and Unique impls. in PyTorch
1 parent dfdc114 commit 330a7d2

File tree

1 file changed

+30
-25
lines changed

1 file changed

+30
-25
lines changed

tests/link/pytorch/test_extra_ops.py

+30-25
Original file line numberDiff line numberDiff line change
@@ -43,47 +43,52 @@ def test_pytorch_CumOp(axis, dtype):
4343
compare_pytorch_and_py(fgraph, [test_value])
4444

4545

46-
def test_pytorch_Repeat():
46+
@pytest.mark.parametrize("axis", [0, 1])
47+
def test_pytorch_Repeat(axis):
4748
a = pt.matrix("a", dtype="float64")
4849

4950
test_value = np.arange(6, dtype="float64").reshape((3, 2))
5051

51-
# Test along axis 0
52-
out = pt.repeat(a, (1, 2, 3), axis=0)
52+
out = pt.repeat(a, (1, 2, 3) if axis == 0 else (3, 3), axis=axis)
5353
fgraph = FunctionGraph([a], [out])
5454
compare_pytorch_and_py(fgraph, [test_value])
5555

56-
# Test along axis 1
57-
out = pt.repeat(a, (3, 3), axis=1)
58-
fgraph = FunctionGraph([a], [out])
59-
compare_pytorch_and_py(fgraph, [test_value])
6056

61-
62-
def test_pytorch_Unique():
57+
@pytest.mark.parametrize("axis", [0, 1])
58+
def test_pytorch_Unique_axis(axis):
6359
a = pt.matrix("a", dtype="float64")
6460

6561
test_value = np.array(
6662
[[1.0, 1.0, 2.0], [1.0, 1.0, 2.0], [3.0, 3.0, 0.0]], dtype="float64"
6763
)
6864

69-
# Test along axis 0
70-
out = pt.unique(a, axis=0)
71-
fgraph = FunctionGraph([a], [out])
72-
compare_pytorch_and_py(fgraph, [test_value])
73-
74-
# Test along axis 1
75-
out = pt.unique(a, axis=1)
65+
out = pt.unique(a, axis=axis)
7666
fgraph = FunctionGraph([a], [out])
7767
compare_pytorch_and_py(fgraph, [test_value])
7868

79-
# Test with params
80-
out = pt.unique(a, return_inverse=True, return_counts=True, axis=0)
81-
fgraph = FunctionGraph([a], [out[0]])
82-
compare_pytorch_and_py(fgraph, [test_value])
8369

84-
# Test with return_index=True
85-
out = pt.unique(a, return_index=True, axis=0)
86-
fgraph = FunctionGraph([a], [out[0]])
70+
@pytest.mark.parametrize(
71+
"return_index, return_inverse, return_counts",
72+
[
73+
(False, True, False),
74+
(False, True, True),
75+
pytest.param(
76+
True, False, False, marks=pytest.mark.xfail(raises=NotImplementedError)
77+
),
78+
],
79+
)
80+
def test_pytorch_Unique_params(return_index, return_inverse, return_counts):
81+
a = pt.matrix("a", dtype="float64")
82+
test_value = np.array(
83+
[[1.0, 1.0, 2.0], [1.0, 1.0, 2.0], [3.0, 3.0, 0.0]], dtype="float64"
84+
)
8785

88-
with pytest.raises(NotImplementedError):
89-
compare_pytorch_and_py(fgraph, [test_value])
86+
out = pt.unique(
87+
a,
88+
return_index=return_index,
89+
return_inverse=return_inverse,
90+
return_counts=return_counts,
91+
axis=0,
92+
)
93+
fgraph = FunctionGraph([a], [out[0] if isinstance(out, list) else out])
94+
compare_pytorch_and_py(fgraph, [test_value])

0 commit comments

Comments
 (0)