Skip to content

Commit 08c2d66

Browse files
Add helper function and implement AdvancedIncSubtensor
1 parent 7fd83dc commit 08c2d66

File tree

1 file changed

+48
-15
lines changed

1 file changed

+48
-15
lines changed
+48-15
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
22
from pytensor.tensor.subtensor import (
3+
AdvancedIncSubtensor,
34
AdvancedIncSubtensor1,
45
AdvancedSubtensor,
56
AdvancedSubtensor1,
@@ -10,18 +11,26 @@
1011
from pytensor.tensor.type_other import MakeSlice
1112

1213

14+
def check_negative_steps(index):
15+
for i in index:
16+
if isinstance(i, slice):
17+
if i.step and isinstance(i.step, int) and i.step < 0:
18+
raise NotImplementedError(
19+
"Negative step sizes are not supported in Pytorch"
20+
)
21+
elif i.step and not isinstance(i.step, int):
22+
raise NotImplementedError(
23+
"Negative step sizes are not supported in Pytorch"
24+
)
25+
26+
1327
@pytorch_funcify.register(Subtensor)
1428
def pytorch_funcify_Subtensor(op, node, **kwargs):
1529
idx_list = getattr(op, "idx_list", None)
1630

1731
def subtensor(x, *ilists):
1832
indices = indices_from_subtensor(ilists, idx_list)
19-
for i in indices:
20-
if isinstance(i, slice):
21-
if i.step and i.step < 0:
22-
raise NotImplementedError(
23-
"Negative step sizes are not supported in Pytorch"
24-
)
33+
check_negative_steps(indices)
2534

2635
return x[indices]
2736

@@ -32,12 +41,7 @@ def subtensor(x, *ilists):
3241
@pytorch_funcify.register(AdvancedSubtensor)
3342
def pytorch_funcify_AdvSubtensor(op, node, **kwargs):
3443
def advsubtensor(x, *indices):
35-
for i in indices:
36-
if isinstance(i, slice):
37-
if i.step and i.step < 0:
38-
raise NotImplementedError(
39-
"Negative step sizes are not supported in Pytorch"
40-
)
44+
check_negative_steps(indices)
4145
return x[indices]
4246

4347
return advsubtensor
@@ -52,21 +56,23 @@ def makeslice(*x):
5256

5357

5458
@pytorch_funcify.register(IncSubtensor)
55-
@pytorch_funcify.register(AdvancedIncSubtensor1)
5659
def pytorch_funcify_IncSubtensor(op, node, **kwargs):
5760
idx_list = getattr(op, "idx_list", None)
5861

5962
if getattr(op, "set_instead_of_inc", False):
6063

6164
def torch_fn(x, indices, y):
65+
check_negative_steps(indices)
6266
x[indices] = y
6367
return x
6468

6569
else:
6670

6771
def torch_fn(x, indices, y):
68-
x[indices] += y
69-
return x
72+
check_negative_steps(indices)
73+
x1 = x.clone()
74+
x1[indices] += y
75+
return x1
7076

7177
def incsubtensor(x, y, *ilist, torch_fn=torch_fn, idx_list=idx_list):
7278
indices = indices_from_subtensor(ilist, idx_list)
@@ -76,3 +82,30 @@ def incsubtensor(x, y, *ilist, torch_fn=torch_fn, idx_list=idx_list):
7682
return torch_fn(x, indices, y)
7783

7884
return incsubtensor
85+
86+
87+
@pytorch_funcify.register(AdvancedIncSubtensor)
88+
@pytorch_funcify.register(AdvancedIncSubtensor1)
89+
def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
90+
if getattr(op, "set_instead_of_inc", False):
91+
92+
def torch_fn(x, indices, y):
93+
check_negative_steps(indices)
94+
x[indices] = y
95+
return x
96+
97+
else:
98+
99+
def torch_fn(x, indices, y):
100+
check_negative_steps(indices)
101+
x1 = x.clone()
102+
x1[indices] += y
103+
return x1
104+
105+
def incsubtensor(x, y, *indices, torch_fn=torch_fn):
106+
if len(indices) == 1:
107+
indices = indices[0]
108+
109+
return torch_fn(x, indices, y)
110+
111+
return incsubtensor

0 commit comments

Comments
 (0)