@@ -1,27 +1,36 @@
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
 from pytensor.tensor.subtensor import (
+    AdvancedIncSubtensor,
     AdvancedIncSubtensor1,
     AdvancedSubtensor,
     AdvancedSubtensor1,
     IncSubtensor,
     Subtensor,
     indices_from_subtensor,
 )
 from pytensor.tensor.type_other import MakeSlice
 
 
+def check_negative_steps(index):
+    for i in index:
+        if isinstance(i, slice):
+            if i.step and isinstance(i.step, int) and i.step < 0:
+                raise NotImplementedError(
+                    "Negative step sizes are not supported in Pytorch"
+                )
+            elif i.step and not isinstance(i.step, int):
+                raise NotImplementedError(
+                    "Negative step sizes are not supported in Pytorch"
+                )
+
+
 @pytorch_funcify.register(Subtensor)
 def pytorch_funcify_Subtensor(op, node, **kwargs):
     idx_list = getattr(op, "idx_list", None)
 
     def subtensor(x, *ilists):
         indices = indices_from_subtensor(ilists, idx_list)
-        for i in indices:
-            if isinstance(i, slice):
-                if i.step and i.step < 0:
-                    raise NotImplementedError(
-                        "Negative step sizes are not supported in Pytorch"
-                    )
+        check_negative_steps(indices)
 
         return x[indices]
 
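The new `check_negative_steps` helper exists because PyTorch itself refuses negative slice steps. A minimal sketch of that baseline behaviour, assuming a local `torch` install (the tensor `x` below is illustrative and not part of the patch):

```python
import torch

x = torch.arange(6)

# PyTorch rejects negative-step slicing outright, so the backend raises a
# clearer NotImplementedError up front via check_negative_steps instead of
# letting this lower-level error escape at run time.
try:
    x[::-1]  # negative-step slice
except Exception as err:
    print(f"{type(err).__name__}: {err}")
```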
@@ -32,12 +41,7 @@ def subtensor(x, *ilists):
 @pytorch_funcify.register(AdvancedSubtensor)
 def pytorch_funcify_AdvSubtensor(op, node, **kwargs):
     def advsubtensor(x, *indices):
-        for i in indices:
-            if isinstance(i, slice):
-                if i.step and i.step < 0:
-                    raise NotImplementedError(
-                        "Negative step sizes are not supported in Pytorch"
-                    )
+        check_negative_steps(indices)
         return x[indices]
 
     return advsubtensor
@@ -52,21 +56,23 @@ def makeslice(*x):
 
 
 @pytorch_funcify.register(IncSubtensor)
-@pytorch_funcify.register(AdvancedIncSubtensor1)
 def pytorch_funcify_IncSubtensor(op, node, **kwargs):
     idx_list = getattr(op, "idx_list", None)
 
     if getattr(op, "set_instead_of_inc", False):
 
         def torch_fn(x, indices, y):
+            check_negative_steps(indices)
             x[indices] = y
             return x
 
     else:
 
         def torch_fn(x, indices, y):
-            x[indices] += y
-            return x
+            check_negative_steps(indices)
+            x1 = x.clone()
+            x1[indices] += y
+            return x1
 
     def incsubtensor(x, y, *ilist, torch_fn=torch_fn, idx_list=idx_list):
         indices = indices_from_subtensor(ilist, idx_list)
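The switch from `x[indices] += y` to cloning first matters because indexed `+=` on a torch tensor mutates it in place, which would silently modify the function's input when the Op is not an in-place one. A small illustration of the difference, with illustrative variable names (not taken from the patch):

```python
import torch

# In-place indexed addition changes the caller's tensor...
x = torch.zeros(3)
x[0] += 1.0
print(x)        # tensor([1., 0., 0.])

# ...while cloning first leaves the input untouched, as the non-inplace
# increment path in torch_fn now does.
x = torch.zeros(3)
x1 = x.clone()
x1[0] += 1.0
print(x, x1)    # tensor([0., 0., 0.]) tensor([1., 0., 0.])
```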
@@ -76,3 +82,30 @@ def incsubtensor(x, y, *ilist, torch_fn=torch_fn, idx_list=idx_list):
         return torch_fn(x, indices, y)
 
     return incsubtensor
+
+
+@pytorch_funcify.register(AdvancedIncSubtensor)
+@pytorch_funcify.register(AdvancedIncSubtensor1)
+def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
+    if getattr(op, "set_instead_of_inc", False):
+
+        def torch_fn(x, indices, y):
+            check_negative_steps(indices)
+            x[indices] = y
+            return x
+
+    else:
+
+        def torch_fn(x, indices, y):
+            check_negative_steps(indices)
+            x1 = x.clone()
+            x1[indices] += y
+            return x1
+
+    def incsubtensor(x, y, *indices, torch_fn=torch_fn):
+        if len(indices) == 1:
+            indices = indices[0]
+
+        return torch_fn(x, indices, y)
+
+    return incsubtensor
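For context, a rough usage sketch of what the new `AdvancedIncSubtensor` dispatch enables; this is not from the PR, it assumes the `"PYTORCH"` compile mode is registered in your PyTensor build, and the variable names are illustrative:

```python
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
idx = pt.lvector("idx")
y = pt.vector("y")

# set_subtensor / inc_subtensor graphs lower to the (advanced) inc-subtensor
# Ops whose PyTorch implementations are added or updated above.
set_graph = pt.set_subtensor(x[idx], y)
inc_graph = pt.inc_subtensor(x[idx], y)

fn = pytensor.function([x, idx, y], [set_graph, inc_graph], mode="PYTORCH")
print(fn([0.0, 0.0, 0.0], [0, 2], [1.0, 2.0]))
```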