Skip to content

Commit d454052

Browse files
Parametrize tests
1 parent b68f493 commit d454052

File tree

2 files changed

+53
-97
lines changed

2 files changed

+53
-97
lines changed

pytensor/link/pytorch/dispatch/subtensor.py

+38-29
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,14 @@ def subtensor(x, *ilists):
3737
return subtensor
3838

3939

40+
@pytorch_funcify.register(MakeSlice)
41+
def pytorch_funcify_makeslice(op, **kwargs):
42+
def makeslice(*x):
43+
return slice(x)
44+
45+
return makeslice
46+
47+
4048
@pytorch_funcify.register(AdvancedSubtensor1)
4149
@pytorch_funcify.register(AdvancedSubtensor)
4250
def pytorch_funcify_AdvSubtensor(op, node, **kwargs):
@@ -47,38 +55,35 @@ def advsubtensor(x, *indices):
4755
return advsubtensor
4856

4957

50-
@pytorch_funcify.register(MakeSlice)
51-
def pytorch_funcify_makeslice(op, **kwargs):
52-
def makeslice(*x):
53-
return slice(x)
54-
55-
return makeslice
56-
57-
5858
@pytorch_funcify.register(IncSubtensor)
5959
def pytorch_funcify_IncSubtensor(op, node, **kwargs):
6060
idx_list = getattr(op, "idx_list", None)
6161

6262
if getattr(op, "set_instead_of_inc", False):
6363

6464
def torch_fn(x, indices, y):
65-
check_negative_steps(indices)
66-
x[indices] = y
67-
return x
65+
if op.inplace:
66+
x[tuple(indices)] = y
67+
return x
68+
x1 = x.clone()
69+
x1[tuple(indices)] = y
70+
return x1
6871

6972
else:
7073

7174
def torch_fn(x, indices, y):
72-
check_negative_steps(indices)
73-
x1 = x.clone()
74-
x1[indices] += y
75-
return x1
76-
77-
def incsubtensor(x, y, *ilist, torch_fn=torch_fn, idx_list=idx_list):
75+
if op.inplace:
76+
x[tuple(indices)] += y
77+
return x
78+
else:
79+
x1 = x.clone()
80+
x1[tuple(indices)] += y
81+
return x1
82+
83+
def incsubtensor(x, y, *ilist):
7884
indices = indices_from_subtensor(ilist, idx_list)
79-
if len(indices) == 1:
80-
indices = indices[0]
8185

86+
check_negative_steps(indices)
8287
return torch_fn(x, indices, y)
8388

8489
return incsubtensor
@@ -90,22 +95,26 @@ def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
9095
if getattr(op, "set_instead_of_inc", False):
9196

9297
def torch_fn(x, indices, y):
93-
check_negative_steps(indices)
94-
x[indices] = y
95-
return x
98+
if op.inplace:
99+
x[tuple(indices)] = y
100+
return x
101+
x1 = x.clone()
102+
x1[tuple(indices)] = y
103+
return x1
96104

97105
else:
98106

99107
def torch_fn(x, indices, y):
100-
check_negative_steps(indices)
101-
x1 = x.clone()
102-
x1[indices] += y
103-
return x1
108+
if op.inplace:
109+
x[tuple(indices)] += y
110+
return x
111+
else:
112+
x1 = x.clone()
113+
x1[tuple(indices)] += y
114+
return x1
104115

105116
def incsubtensor(x, y, *indices, torch_fn=torch_fn):
106-
if len(indices) == 1:
107-
indices = indices[0]
108-
117+
check_negative_steps(indices)
109118
return torch_fn(x, indices, y)
110119

111120
return incsubtensor

tests/link/pytorch/test_subtensor.py

+15-68
Original file line numberDiff line numberDiff line change
@@ -109,83 +109,36 @@ def test_pytorch_AdvSubtensor():
109109
compare_pytorch_and_py(out_fg, [x_np])
110110

111111

112-
def test_pytorch_SetSubtensor():
112+
@pytest.mark.parametrize(
113+
"subtensor_op", [pt_subtensor.set_subtensor, pt_subtensor.inc_subtensor]
114+
)
115+
def test_pytorch_SetSubtensor(subtensor_op):
113116
x_pt = pt.tensor3("x")
114117
x_test = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX)
115118

116119
# "Set" basic indices
117120
st_pt = pt.as_tensor_variable(np.array(-10.0, dtype=config.floatX))
118-
out_pt = pt_subtensor.set_subtensor(x_pt[1, 2, 3], st_pt)
121+
out_pt = subtensor_op(x_pt[1, 2, 3], st_pt)
119122
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
120123
out_fg = FunctionGraph([x_pt], [out_pt])
121124
compare_pytorch_and_py(out_fg, [x_test])
122125

123126
st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
124-
out_pt = pt_subtensor.set_subtensor(x_pt[:2, 0, 0], st_pt)
127+
out_pt = subtensor_op(x_pt[:2, 0, 0], st_pt)
125128
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
126129
out_fg = FunctionGraph([x_pt], [out_pt])
127130
compare_pytorch_and_py(out_fg, [x_test])
128131

129-
out_pt = pt_subtensor.set_subtensor(x_pt[0, 1:3, 0], st_pt)
132+
out_pt = subtensor_op(x_pt[0, 1:3, 0], st_pt)
130133
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
131134
out_fg = FunctionGraph([x_pt], [out_pt])
132135
compare_pytorch_and_py(out_fg, [x_test])
133136

134-
out_pt = pt_subtensor.set_subtensor(x_pt[0, 1:3, 0], st_pt)
135-
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
136-
out_fg = FunctionGraph([x_pt], [out_pt])
137-
compare_pytorch_and_py(out_fg, [x_test])
138137

139-
140-
def test_pytorch_AdvSetSubtensor():
141-
rng = np.random.default_rng(42)
142-
143-
x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX)
144-
x_pt = pt.tensor3("x")
145-
x_test = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX)
146-
147-
# "Set" advanced indices
148-
st_pt = pt.as_tensor_variable(
149-
rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX)
150-
)
151-
out_pt = pt_subtensor.set_subtensor(x_pt[np.r_[0, 2]], st_pt)
152-
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
153-
out_fg = FunctionGraph([x_pt], [out_pt])
154-
compare_pytorch_and_py(out_fg, [x_test])
155-
156-
st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
157-
out_pt = pt_subtensor.set_subtensor(x_pt[[0, 2], 0, 0], st_pt)
158-
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
159-
out_fg = FunctionGraph([x_pt], [out_pt])
160-
compare_pytorch_and_py(out_fg, [x_test])
161-
162-
# "Set" boolean indices
163-
mask_pt = pt.constant(x_np > 0)
164-
out_pt = pt_subtensor.set_subtensor(x_pt[mask_pt], 0.0)
165-
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
166-
out_fg = FunctionGraph([x_pt], [out_pt])
167-
compare_pytorch_and_py(out_fg, [x_test])
168-
169-
170-
def test_pytorch_IncSubtensor():
171-
x_pt = pt.tensor3("x")
172-
x_test = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX)
173-
174-
# "Increment" basic indices
175-
st_pt = pt.as_tensor_variable(np.array(-10.0, dtype=config.floatX))
176-
out_pt = pt_subtensor.inc_subtensor(x_pt[1, 2, 3], st_pt)
177-
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
178-
out_fg = FunctionGraph([x_pt], [out_pt])
179-
compare_pytorch_and_py(out_fg, [x_test])
180-
181-
st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
182-
out_pt = pt_subtensor.inc_subtensor(x_pt[:2, 0, 0], st_pt)
183-
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
184-
out_fg = FunctionGraph([x_pt], [out_pt])
185-
compare_pytorch_and_py(out_fg, [x_test])
186-
187-
188-
def test_pytorch_AvdancedIncSubtensor():
138+
@pytest.mark.parametrize(
139+
"advsubtensor_op", [pt_subtensor.set_subtensor, pt_subtensor.inc_subtensor]
140+
)
141+
def test_pytorch_AvdancedIncSubtensor(advsubtensor_op):
189142
rng = np.random.default_rng(42)
190143

191144
x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX)
@@ -196,32 +149,26 @@ def test_pytorch_AvdancedIncSubtensor():
196149
st_pt = pt.as_tensor_variable(
197150
rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX)
198151
)
199-
out_pt = pt_subtensor.inc_subtensor(x_pt[np.r_[0, 2]], st_pt)
152+
out_pt = advsubtensor_op(x_pt[np.r_[0, 2]], st_pt)
200153
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
201154
out_fg = FunctionGraph([x_pt], [out_pt])
202155
compare_pytorch_and_py(out_fg, [x_test])
203156

204157
st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
205-
out_pt = pt_subtensor.inc_subtensor(x_pt[[0, 2], 0, 0], st_pt)
158+
out_pt = advsubtensor_op(x_pt[[0, 2], 0, 0], st_pt)
206159
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
207160
out_fg = FunctionGraph([x_pt], [out_pt])
208161
compare_pytorch_and_py(out_fg, [x_test])
209162

210163
# "Increment" boolean indices
211164
mask_pt = pt.constant(x_np > 0)
212-
out_pt = pt_subtensor.set_subtensor(x_pt[mask_pt], 1.0)
213-
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
214-
out_fg = FunctionGraph([x_pt], [out_pt])
215-
compare_pytorch_and_py(out_fg, [x_test])
216-
217-
st_pt = pt.as_tensor_variable(x_np[[0, 2], 0, :3])
218-
out_pt = pt_subtensor.set_subtensor(x_pt[[0, 2], 0, :3], st_pt)
165+
out_pt = advsubtensor_op(x_pt[mask_pt], 1.0)
219166
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
220167
out_fg = FunctionGraph([x_pt], [out_pt])
221168
compare_pytorch_and_py(out_fg, [x_test])
222169

223170
st_pt = pt.as_tensor_variable(x_np[[0, 2], 0, :3])
224-
out_pt = pt_subtensor.inc_subtensor(x_pt[[0, 2], 0, :3], st_pt)
171+
out_pt = advsubtensor_op(x_pt[[0, 2], 0, :3], st_pt)
225172
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
226173
out_fg = FunctionGraph([x_pt], [out_pt])
227174
compare_pytorch_and_py(out_fg, [x_test])

0 commit comments

Comments
 (0)