1
1
from pytensor .link .pytorch .dispatch .basic import pytorch_funcify
2
2
from pytensor .tensor .subtensor import (
3
+ AdvancedIncSubtensor ,
3
4
AdvancedIncSubtensor1 ,
4
5
AdvancedSubtensor ,
5
6
AdvancedSubtensor1 ,
10
11
from pytensor .tensor .type_other import MakeSlice
11
12
12
13
14
def check_negative_steps(index):
    """Raise if any entry of ``index`` is a slice PyTorch cannot index with.

    PyTorch does not support negative slice steps. A slice whose step is a
    negative integer is rejected outright; a slice whose (truthy) step is not
    an integer cannot be proven non-negative, so it is rejected conservatively
    as well. Non-slice entries are ignored.

    Parameters
    ----------
    index : iterable
        Index entries as they would be used in ``x[index]``.

    Raises
    ------
    NotImplementedError
        If any slice has a negative integer step or a truthy non-integer step.
    """
    for i in index:
        # Only slices can carry a step; a falsy step (None or 0) is left
        # alone, matching the original behavior.
        if isinstance(i, slice) and i.step:
            # Either the step is known to be negative, or its sign cannot be
            # checked (non-integer, e.g. symbolic) — reject both.
            if not isinstance(i.step, int) or i.step < 0:
                raise NotImplementedError(
                    "Negative step sizes are not supported in Pytorch"
                )
13
27
@pytorch_funcify.register(Subtensor)
def pytorch_funcify_Subtensor(op, node, **kwargs):
    """Generate a PyTorch implementation of basic ``Subtensor`` indexing."""
    idx_list = getattr(op, "idx_list", None)

    def subtensor(x, *flattened_indices):
        # Rebuild the concrete index tuple from the op's static index
        # description and the flattened runtime index inputs.
        indices = indices_from_subtensor(flattened_indices, idx_list)
        # PyTorch cannot handle negative slice steps; fail loudly up front.
        check_negative_steps(indices)

        return x[indices]

    return subtensor
@@ -32,12 +41,7 @@ def subtensor(x, *ilists):
32
41
@pytorch_funcify.register(AdvancedSubtensor)
def pytorch_funcify_AdvSubtensor(op, node, **kwargs):
    """Generate a PyTorch implementation of advanced (fancy) indexing."""

    def advsubtensor(x, *indices):
        # Guard first: PyTorch rejects negative slice steps.
        check_negative_steps(indices)

        return x[indices]

    return advsubtensor
@@ -52,21 +56,23 @@ def makeslice(*x):
52
56
53
57
54
58
@pytorch_funcify.register(IncSubtensor)
def pytorch_funcify_IncSubtensor(op, node, **kwargs):
    """Generate a PyTorch implementation of ``IncSubtensor``.

    Produces either a set (``x[indices] = y``) or an increment
    (``x[indices] += y``) function, selected by ``op.set_instead_of_inc``.
    """
    idx_list = getattr(op, "idx_list", None)

    if getattr(op, "set_instead_of_inc", False):

        def torch_fn(x, indices, y):
            check_negative_steps(indices)
            # NOTE(review): this branch writes into `x` in place, unlike the
            # increment branch below which clones first — confirm callers
            # tolerate the input being mutated.
            x[indices] = y
            return x

    else:

        def torch_fn(x, indices, y):
            check_negative_steps(indices)
            out = x.clone()  # avoid mutating the graph input
            out[indices] += y
            return out

    def incsubtensor(x, y, *ilist, torch_fn=torch_fn, idx_list=idx_list):
        # Reassemble the concrete indices from the flattened runtime inputs.
        indices = indices_from_subtensor(ilist, idx_list)
        # A lone index entry is used directly rather than as a 1-tuple —
        # presumably to sidestep tuple-index interpretation; kept as-is.
        if len(indices) == 1:
            indices = indices[0]

        return torch_fn(x, indices, y)

    return incsubtensor
85
+
86
+
87
@pytorch_funcify.register(AdvancedIncSubtensor)
@pytorch_funcify.register(AdvancedIncSubtensor1)
def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
    """Generate a PyTorch implementation of advanced set/inc-subtensor.

    Returns ``incsubtensor(x, y, *indices)`` which either sets
    (``x[indices] = y``) or increments (``x[indices] += y``) according to
    ``op.set_instead_of_inc``. The input ``x`` is never mutated: both
    branches operate on a clone. (The original set branch wrote into ``x``
    in place, inconsistently with the increment branch — fixed here.)
    """
    if getattr(op, "set_instead_of_inc", False):

        def torch_fn(x, indices, y):
            check_negative_steps(indices)
            # Clone so the graph input is not mutated, matching the
            # increment branch.
            x1 = x.clone()
            x1[indices] = y
            return x1

    else:

        def torch_fn(x, indices, y):
            check_negative_steps(indices)
            x1 = x.clone()
            # NOTE(review): `+=` with advanced indexing does not accumulate
            # over duplicate indices (unlike
            # `Tensor.index_put_(..., accumulate=True)`). Confirm the op's
            # duplicate-index semantics before relying on this.
            x1[indices] += y
            return x1

    def incsubtensor(x, y, *indices, torch_fn=torch_fn):
        # A lone index entry is used directly rather than as a 1-tuple —
        # presumably to sidestep tuple-index interpretation; kept as-is.
        if len(indices) == 1:
            indices = indices[0]

        return torch_fn(x, indices, y)

    return incsubtensor
0 commit comments