Skip to content

Commit b38a01c

Browse files
Implement indexing operations in pytorch
Co-authored-by: Ricardo Vieira <[email protected]>
1 parent 1a1c62b commit b38a01c

File tree

6 files changed

+345
-9
lines changed

6 files changed

+345
-9
lines changed

pytensor/compile/mode.py

+1
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
471471
"BlasOpt",
472472
"fusion",
473473
"inplace",
474+
"local_uint_constant_indices",
474475
],
475476
),
476477
)

pytensor/link/pytorch/dispatch/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
import pytensor.link.pytorch.dispatch.elemwise
88
import pytensor.link.pytorch.dispatch.math
99
import pytensor.link.pytorch.dispatch.extra_ops
10+
import pytensor.link.pytorch.dispatch.nlinalg
1011
import pytensor.link.pytorch.dispatch.shape
1112
import pytensor.link.pytorch.dispatch.sort
12-
import pytensor.link.pytorch.dispatch.nlinalg
13+
import pytensor.link.pytorch.dispatch.subtensor
1314
# isort: on

pytensor/link/pytorch/dispatch/basic.py

+29-5
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,40 @@
11
from functools import singledispatch
22
from types import NoneType
33

4+
import numpy as np
45
import torch
56

67
from pytensor.compile.ops import DeepCopyOp
78
from pytensor.graph.fg import FunctionGraph
89
from pytensor.link.utils import fgraph_to_python
910
from pytensor.raise_op import CheckAndRaise
10-
from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Eye, Join, MakeVector
11+
from pytensor.tensor.basic import (
12+
Alloc,
13+
AllocEmpty,
14+
ARange,
15+
Eye,
16+
Join,
17+
MakeVector,
18+
TensorFromScalar,
19+
)
1120

1221

1322
@singledispatch
14-
def pytorch_typify(data, dtype=None, **kwargs):
15-
r"""Convert instances of PyTensor `Type`\s to PyTorch types."""
23+
def pytorch_typify(data, **kwargs):
24+
raise NotImplementedError(f"pytorch_typify is not implemented for {type(data)}")
25+
26+
27+
@pytorch_typify.register(np.ndarray)
28+
@pytorch_typify.register(torch.Tensor)
29+
def pytorch_typify_tensor(data, dtype=None, **kwargs):
1630
return torch.as_tensor(data, dtype=dtype)
1731

1832

33+
@pytorch_typify.register(slice)
1934
@pytorch_typify.register(NoneType)
20-
def pytorch_typify_None(data, **kwargs):
21-
return None
35+
@pytorch_typify.register(np.number)
36+
def pytorch_typify_no_conversion_needed(data, **kwargs):
37+
return data
2238

2339

2440
@singledispatch
@@ -132,3 +148,11 @@ def makevector(*x):
132148
return torch.tensor(x, dtype=torch_dtype)
133149

134150
return makevector
151+
152+
153+
@pytorch_funcify.register(TensorFromScalar)
154+
def pytorch_funcify_TensorFromScalar(op, **kwargs):
155+
def tensorfromscalar(x):
156+
return torch.as_tensor(x)
157+
158+
return tensorfromscalar
+124
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
2+
from pytensor.tensor.subtensor import (
3+
AdvancedIncSubtensor,
4+
AdvancedIncSubtensor1,
5+
AdvancedSubtensor,
6+
AdvancedSubtensor1,
7+
IncSubtensor,
8+
Subtensor,
9+
indices_from_subtensor,
10+
)
11+
from pytensor.tensor.type_other import MakeSlice, SliceType
12+
13+
14+
def check_negative_steps(indices):
15+
for index in indices:
16+
if isinstance(index, slice):
17+
if index.step is not None and index.step < 0:
18+
raise NotImplementedError(
19+
"Negative step sizes are not supported in Pytorch"
20+
)
21+
22+
23+
@pytorch_funcify.register(Subtensor)
24+
def pytorch_funcify_Subtensor(op, node, **kwargs):
25+
idx_list = op.idx_list
26+
27+
def subtensor(x, *flattened_indices):
28+
indices = indices_from_subtensor(flattened_indices, idx_list)
29+
check_negative_steps(indices)
30+
return x[indices]
31+
32+
return subtensor
33+
34+
35+
@pytorch_funcify.register(MakeSlice)
36+
def pytorch_funcify_makeslice(op, **kwargs):
37+
def makeslice(*x):
38+
return slice(x)
39+
40+
return makeslice
41+
42+
43+
@pytorch_funcify.register(AdvancedSubtensor1)
44+
@pytorch_funcify.register(AdvancedSubtensor)
45+
def pytorch_funcify_AdvSubtensor(op, node, **kwargs):
46+
def advsubtensor(x, *indices):
47+
check_negative_steps(indices)
48+
return x[indices]
49+
50+
return advsubtensor
51+
52+
53+
@pytorch_funcify.register(IncSubtensor)
54+
def pytorch_funcify_IncSubtensor(op, node, **kwargs):
55+
idx_list = op.idx_list
56+
inplace = op.inplace
57+
if op.set_instead_of_inc:
58+
59+
def set_subtensor(x, y, *flattened_indices):
60+
indices = indices_from_subtensor(flattened_indices, idx_list)
61+
check_negative_steps(indices)
62+
if not inplace:
63+
x = x.clone()
64+
x[indices] = y
65+
return x
66+
67+
return set_subtensor
68+
69+
else:
70+
71+
def inc_subtensor(x, y, *flattened_indices):
72+
indices = indices_from_subtensor(flattened_indices, idx_list)
73+
check_negative_steps(indices)
74+
if not inplace:
75+
x = x.clone()
76+
x[indices] += y
77+
return x
78+
79+
return inc_subtensor
80+
81+
82+
@pytorch_funcify.register(AdvancedIncSubtensor)
83+
@pytorch_funcify.register(AdvancedIncSubtensor1)
84+
def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
85+
inplace = op.inplace
86+
ignore_duplicates = getattr(op, "ignore_duplicates", False)
87+
88+
if op.set_instead_of_inc:
89+
90+
def adv_set_subtensor(x, y, *indices):
91+
check_negative_steps(indices)
92+
if not inplace:
93+
x = x.clone()
94+
x[indices] = y.type_as(x)
95+
return x
96+
97+
return adv_set_subtensor
98+
99+
elif ignore_duplicates:
100+
101+
def adv_inc_subtensor_no_duplicates(x, y, *indices):
102+
check_negative_steps(indices)
103+
if not inplace:
104+
x = x.clone()
105+
x[indices] += y.type_as(x)
106+
return x
107+
108+
return adv_inc_subtensor_no_duplicates
109+
110+
else:
111+
if any(isinstance(idx.type, SliceType) for idx in node.inputs[2:]):
112+
raise NotImplementedError(
113+
"IncSubtensor with potential duplicates indexes and slice indexing not implemented in PyTorch"
114+
)
115+
116+
def adv_inc_subtensor(x, y, *indices):
117+
# Not needed because slices aren't supported
118+
# check_negative_steps(indices)
119+
if not inplace:
120+
x = x.clone()
121+
x.index_put_(indices, y.type_as(x), accumulate=True)
122+
return x
123+
124+
return adv_inc_subtensor

tests/link/pytorch/test_basic.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,10 @@ def compare_pytorch_and_py(
6666
py_res = pytensor_py_fn(*test_inputs)
6767

6868
if len(fgraph.outputs) > 1:
69-
for j, p in zip(pytorch_res, py_res):
70-
assert_fn(j.cpu(), p)
69+
for pytorch_res_i, py_res_i in zip(pytorch_res, py_res):
70+
assert_fn(pytorch_res_i.detach().cpu().numpy(), py_res_i)
7171
else:
72-
assert_fn([pytorch_res[0].cpu()], py_res)
72+
assert_fn(pytorch_res[0].detach().cpu().numpy(), py_res[0])
7373

7474
return pytensor_torch_fn, pytorch_res
7575

0 commit comments

Comments
 (0)