Skip to content

Commit b68f493

Browse files
Add tests for AdvancedIncSubtensor and SetSubtensor
1 parent 08c2d66 commit b68f493

File tree

1 file changed

+61
-34
lines changed

1 file changed

+61
-34
lines changed

tests/link/pytorch/test_subtensor.py

+61-34
Original file line numberDiff line numberDiff line change
@@ -100,101 +100,128 @@ def test_pytorch_AdvSubtensor():
100100
out_fg = FunctionGraph([x_pt, a_pt], [out_pt])
101101
compare_pytorch_and_py(out_fg, [x_np, a_np])
102102

103+
with pytest.raises(
104+
NotImplementedError, match="Negative step sizes are not supported in Pytorch"
105+
):
106+
out_pt = x_pt[[1, 2], ::-1]
107+
out_fg = FunctionGraph([x_pt], [out_pt])
108+
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
109+
compare_pytorch_and_py(out_fg, [x_np])
103110

104-
def test_pytorch_IncSubtensor():
105-
rng = np.random.default_rng(42)
106111

107-
x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX)
108-
x_pt = pt.constant(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX))
112+
def test_pytorch_SetSubtensor():
113+
x_pt = pt.tensor3("x")
114+
x_test = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX)
109115

110116
# "Set" basic indices
111117
st_pt = pt.as_tensor_variable(np.array(-10.0, dtype=config.floatX))
112118
out_pt = pt_subtensor.set_subtensor(x_pt[1, 2, 3], st_pt)
113119
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
114-
out_fg = FunctionGraph([], [out_pt])
115-
compare_pytorch_and_py(out_fg, [])
120+
out_fg = FunctionGraph([x_pt], [out_pt])
121+
compare_pytorch_and_py(out_fg, [x_test])
116122

117123
st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
118124
out_pt = pt_subtensor.set_subtensor(x_pt[:2, 0, 0], st_pt)
119125
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
120-
out_fg = FunctionGraph([], [out_pt])
121-
compare_pytorch_and_py(out_fg, [])
126+
out_fg = FunctionGraph([x_pt], [out_pt])
127+
compare_pytorch_and_py(out_fg, [x_test])
128+
129+
out_pt = pt_subtensor.set_subtensor(x_pt[0, 1:3, 0], st_pt)
130+
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
131+
out_fg = FunctionGraph([x_pt], [out_pt])
132+
compare_pytorch_and_py(out_fg, [x_test])
122133

123134
out_pt = pt_subtensor.set_subtensor(x_pt[0, 1:3, 0], st_pt)
124135
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
125-
out_fg = FunctionGraph([], [out_pt])
126-
compare_pytorch_and_py(out_fg, [])
136+
out_fg = FunctionGraph([x_pt], [out_pt])
137+
compare_pytorch_and_py(out_fg, [x_test])
138+
139+
140+
def test_pytorch_AdvSetSubtensor():
141+
rng = np.random.default_rng(42)
142+
143+
x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX)
144+
x_pt = pt.tensor3("x")
145+
x_test = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX)
127146

128147
# "Set" advanced indices
129148
st_pt = pt.as_tensor_variable(
130149
rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX)
131150
)
132151
out_pt = pt_subtensor.set_subtensor(x_pt[np.r_[0, 2]], st_pt)
133152
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
134-
out_fg = FunctionGraph([], [out_pt])
135-
compare_pytorch_and_py(out_fg, [])
153+
out_fg = FunctionGraph([x_pt], [out_pt])
154+
compare_pytorch_and_py(out_fg, [x_test])
136155

137156
st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
138157
out_pt = pt_subtensor.set_subtensor(x_pt[[0, 2], 0, 0], st_pt)
139158
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
140-
out_fg = FunctionGraph([], [out_pt])
141-
compare_pytorch_and_py(out_fg, [])
159+
out_fg = FunctionGraph([x_pt], [out_pt])
160+
compare_pytorch_and_py(out_fg, [x_test])
142161

143162
# "Set" boolean indices
144163
mask_pt = pt.constant(x_np > 0)
145164
out_pt = pt_subtensor.set_subtensor(x_pt[mask_pt], 0.0)
146165
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
147-
out_fg = FunctionGraph([], [out_pt])
148-
compare_pytorch_and_py(out_fg, [])
166+
out_fg = FunctionGraph([x_pt], [out_pt])
167+
compare_pytorch_and_py(out_fg, [x_test])
168+
169+
170+
def test_pytorch_IncSubtensor():
171+
x_pt = pt.tensor3("x")
172+
x_test = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX)
149173

150174
# "Increment" basic indices
151175
st_pt = pt.as_tensor_variable(np.array(-10.0, dtype=config.floatX))
152176
out_pt = pt_subtensor.inc_subtensor(x_pt[1, 2, 3], st_pt)
153177
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
154-
out_fg = FunctionGraph([], [out_pt])
155-
compare_pytorch_and_py(out_fg, [])
178+
out_fg = FunctionGraph([x_pt], [out_pt])
179+
compare_pytorch_and_py(out_fg, [x_test])
156180

157181
st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
158182
out_pt = pt_subtensor.inc_subtensor(x_pt[:2, 0, 0], st_pt)
159183
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
160-
out_fg = FunctionGraph([], [out_pt])
161-
compare_pytorch_and_py(out_fg, [])
184+
out_fg = FunctionGraph([x_pt], [out_pt])
185+
compare_pytorch_and_py(out_fg, [x_test])
162186

163-
out_pt = pt_subtensor.set_subtensor(x_pt[0, 1:3, 0], st_pt)
164-
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
165-
out_fg = FunctionGraph([], [out_pt])
166-
compare_pytorch_and_py(out_fg, [])
187+
188+
def test_pytorch_AvdancedIncSubtensor():
189+
rng = np.random.default_rng(42)
190+
191+
x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX)
192+
x_pt = pt.tensor3("x")
193+
x_test = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX)
167194

168195
# "Increment" advanced indices
169196
st_pt = pt.as_tensor_variable(
170197
rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX)
171198
)
172199
out_pt = pt_subtensor.inc_subtensor(x_pt[np.r_[0, 2]], st_pt)
173200
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
174-
out_fg = FunctionGraph([], [out_pt])
175-
compare_pytorch_and_py(out_fg, [])
201+
out_fg = FunctionGraph([x_pt], [out_pt])
202+
compare_pytorch_and_py(out_fg, [x_test])
176203

177204
st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
178205
out_pt = pt_subtensor.inc_subtensor(x_pt[[0, 2], 0, 0], st_pt)
179206
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
180-
out_fg = FunctionGraph([], [out_pt])
181-
compare_pytorch_and_py(out_fg, [])
207+
out_fg = FunctionGraph([x_pt], [out_pt])
208+
compare_pytorch_and_py(out_fg, [x_test])
182209

183210
# "Increment" boolean indices
184211
mask_pt = pt.constant(x_np > 0)
185212
out_pt = pt_subtensor.set_subtensor(x_pt[mask_pt], 1.0)
186213
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
187-
out_fg = FunctionGraph([], [out_pt])
188-
compare_pytorch_and_py(out_fg, [])
214+
out_fg = FunctionGraph([x_pt], [out_pt])
215+
compare_pytorch_and_py(out_fg, [x_test])
189216

190217
st_pt = pt.as_tensor_variable(x_np[[0, 2], 0, :3])
191218
out_pt = pt_subtensor.set_subtensor(x_pt[[0, 2], 0, :3], st_pt)
192219
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
193-
out_fg = FunctionGraph([], [out_pt])
194-
compare_pytorch_and_py(out_fg, [])
220+
out_fg = FunctionGraph([x_pt], [out_pt])
221+
compare_pytorch_and_py(out_fg, [x_test])
195222

196223
st_pt = pt.as_tensor_variable(x_np[[0, 2], 0, :3])
197224
out_pt = pt_subtensor.inc_subtensor(x_pt[[0, 2], 0, :3], st_pt)
198225
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
199-
out_fg = FunctionGraph([], [out_pt])
200-
compare_pytorch_and_py(out_fg, [])
226+
out_fg = FunctionGraph([x_pt], [out_pt])
227+
compare_pytorch_and_py(out_fg, [x_test])

0 commit comments

Comments
 (0)