Skip to content

Commit 320bac4

Browse files
Add initial support for PyTorch backend (#764)
1 parent efa845a commit 320bac4

File tree

11 files changed

+471
-2
lines changed

11 files changed

+471
-2
lines changed

.github/workflows/test.yml

+12-1
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ jobs:
7676
float32: [0, 1]
7777
install-numba: [0]
7878
install-jax: [0]
79+
install-torch: [0]
7980
part:
8081
- "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse"
8182
- "tests/scan"
@@ -116,6 +117,11 @@ jobs:
116117
fast-compile: 0
117118
float32: 0
118119
part: "tests/link/jax"
120+
- install-torch: 1
121+
python-version: "3.10"
122+
fast-compile: 0
123+
float32: 0
124+
part: "tests/link/pytorch"
119125
steps:
120126
- uses: actions/checkout@v4
121127
with:
@@ -142,9 +148,12 @@ jobs:
142148
- name: Install dependencies
143149
shell: micromamba-shell {0}
144150
run: |
151+
145152
micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock sympy
146153
if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57"; fi
147154
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro && pip install tensorflow-probability; fi
155+
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" pytorch pytorch-cuda=12.1 -c pytorch -c nvidia; fi
156+
148157
pip install -e ./
149158
micromamba list && pip freeze
150159
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
@@ -153,6 +162,7 @@ jobs:
153162
PYTHON_VERSION: ${{ matrix.python-version }}
154163
INSTALL_NUMBA: ${{ matrix.install-numba }}
155164
INSTALL_JAX: ${{ matrix.install-jax }}
165+
INSTALL_TORCH: ${{ matrix.install-torch}}
156166

157167
- name: Run tests
158168
shell: micromamba-shell {0}
@@ -199,7 +209,7 @@ jobs:
199209
- name: Install dependencies
200210
shell: micromamba-shell {0}
201211
run: |
202-
micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service cython pytest "numba>=0.57" jax jaxlib pytest-benchmark
212+
micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service cython pytest "numba>=0.57" jax jaxlib pytest-benchmark pytorch pytorch-cuda=12.1 -c pytorch -c nvidia
203213
pip install -e ./
204214
micromamba list && pip freeze
205215
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
@@ -268,3 +278,4 @@ jobs:
268278
directory: ./coverage/
269279
fail_ci_if_error: true
270280
token: ${{ secrets.CODECOV_TOKEN }}
281+

pytensor/compile/mode.py

+15
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from pytensor.link.c.basic import CLinker, OpWiseCLinker
2929
from pytensor.link.jax.linker import JAXLinker
3030
from pytensor.link.numba.linker import NumbaLinker
31+
from pytensor.link.pytorch.linker import PytorchLinker
3132
from pytensor.link.vm import VMLinker
3233

3334

@@ -47,6 +48,7 @@
4748
"vm_nogc": VMLinker(allow_gc=False, use_cloop=False),
4849
"cvm_nogc": VMLinker(allow_gc=False, use_cloop=True),
4950
"jax": JAXLinker(),
51+
"pytorch": PytorchLinker(),
5052
"numba": NumbaLinker(),
5153
}
5254

@@ -460,6 +462,18 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
460462
],
461463
),
462464
)
465+
PYTORCH = Mode(
466+
PytorchLinker(),
467+
RewriteDatabaseQuery(
468+
include=["fast_run"],
469+
exclude=[
470+
"cxx_only",
471+
"BlasOpt",
472+
"fusion",
473+
"inplace",
474+
],
475+
),
476+
)
463477
NUMBA = Mode(
464478
NumbaLinker(),
465479
RewriteDatabaseQuery(
@@ -474,6 +488,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
474488
"FAST_RUN": FAST_RUN,
475489
"JAX": JAX,
476490
"NUMBA": NUMBA,
491+
"PYTORCH": PYTORCH,
477492
}
478493

479494
instantiated_default_mode = None

pytensor/link/basic.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,10 @@ def create_thunk_inputs(self, storage_map: dict[Variable, list[Any]]) -> list[An
600600
def jit_compile(self, fn: Callable) -> Callable:
601601
"""JIT compile a converted ``FunctionGraph``."""
602602

603+
def input_filter(self, inp: Any) -> Any:
604+
"""Apply a filter to the data input."""
605+
return inp
606+
603607
def output_filter(self, var: Variable, out: Any) -> Any:
604608
"""Apply a filter to the data output by a JITed function call."""
605609
return out
@@ -657,7 +661,7 @@ def thunk(
657661
thunk_inputs=thunk_inputs,
658662
thunk_outputs=thunk_outputs,
659663
):
660-
outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
664+
outputs = fgraph_jit(*[self.input_filter(x[0]) for x in thunk_inputs])
661665

662666
for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
663667
compute_map[o_var][0] = True
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# isort: off
2+
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify, pytorch_typify
3+
4+
# # Load dispatch specializations
5+
import pytensor.link.pytorch.dispatch.scalar
6+
import pytensor.link.pytorch.dispatch.elemwise
7+
# isort: on
+60
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from functools import singledispatch
2+
3+
import torch
4+
5+
from pytensor.compile.ops import DeepCopyOp
6+
from pytensor.graph.fg import FunctionGraph
7+
from pytensor.link.utils import fgraph_to_python
8+
from pytensor.raise_op import CheckAndRaise
9+
10+
11+
@singledispatch
12+
def pytorch_typify(data, dtype=None, **kwargs):
13+
r"""Convert instances of PyTensor `Type`\s to PyTorch types."""
14+
return torch.as_tensor(data, dtype=dtype)
15+
16+
17+
@singledispatch
18+
def pytorch_funcify(op, node=None, storage_map=None, **kwargs):
19+
"""Create a PyTorch compatible function from an PyTensor `Op`."""
20+
raise NotImplementedError(
21+
f"No PyTorch conversion for the given `Op`: {op}.\nCheck out `https://github.com/pymc-devs/pytensor/issues/821` for progress or to request we prioritize this operation"
22+
)
23+
24+
25+
@pytorch_funcify.register(FunctionGraph)
26+
def pytorch_funcify_FunctionGraph(
27+
fgraph,
28+
node=None,
29+
fgraph_name="pytorch_funcified_fgraph",
30+
**kwargs,
31+
):
32+
return fgraph_to_python(
33+
fgraph,
34+
pytorch_funcify,
35+
type_conversion_fn=pytorch_typify,
36+
fgraph_name=fgraph_name,
37+
**kwargs,
38+
)
39+
40+
41+
@pytorch_funcify.register(CheckAndRaise)
42+
def pytorch_funcify_CheckAndRaise(op, **kwargs):
43+
error = op.exc_type
44+
msg = op.msg
45+
46+
def assert_fn(x, *conditions):
47+
for cond in conditions:
48+
if not cond.item():
49+
raise error(msg)
50+
return x
51+
52+
return assert_fn
53+
54+
55+
@pytorch_funcify.register(DeepCopyOp)
56+
def pytorch_funcify_DeepCopyOp(op, **kwargs):
57+
def deepcopyop(x):
58+
return x.clone()
59+
60+
return deepcopyop
+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import torch
2+
3+
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
4+
from pytensor.tensor.elemwise import DimShuffle, Elemwise
5+
6+
7+
@pytorch_funcify.register(Elemwise)
8+
def pytorch_funcify_Elemwise(op, node, **kwargs):
9+
scalar_op = op.scalar_op
10+
base_fn = pytorch_funcify(scalar_op, node=node, **kwargs)
11+
12+
def elemwise_fn(*inputs):
13+
Elemwise._check_runtime_broadcast(node, inputs)
14+
return base_fn(*inputs)
15+
16+
return elemwise_fn
17+
18+
19+
@pytorch_funcify.register(DimShuffle)
20+
def pytorch_funcify_DimShuffle(op, **kwargs):
21+
def dimshuffle(x):
22+
res = torch.permute(x, op.transposition)
23+
24+
shape = list(res.shape[: len(op.shuffle)])
25+
26+
for augm in op.augment:
27+
shape.insert(augm, 1)
28+
29+
res = torch.reshape(res, shape)
30+
31+
if not op.inplace:
32+
res = res.clone()
33+
34+
return res
35+
36+
return dimshuffle
+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import torch
2+
3+
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
4+
from pytensor.scalar.basic import (
5+
ScalarOp,
6+
)
7+
8+
9+
@pytorch_funcify.register(ScalarOp)
10+
def pytorch_funcify_ScalarOp(op, node, **kwargs):
11+
"""Return pytorch function that implements the same computation as the Scalar Op.
12+
13+
This dispatch is expected to return a pytorch function that works on Array inputs as Elemwise does,
14+
even though it's dispatched on the Scalar Op.
15+
"""
16+
17+
nfunc_spec = getattr(op, "nfunc_spec", None)
18+
if nfunc_spec is None:
19+
raise NotImplementedError(f"Dispatch not implemented for Scalar Op {op}")
20+
21+
func_name = nfunc_spec[0]
22+
23+
pytorch_func = getattr(torch, func_name)
24+
25+
if len(node.inputs) > op.nfunc_spec[1]:
26+
# Some Scalar Ops accept multiple number of inputs, behaving as a variadic function,
27+
# even though the base Op from `func_name` is specified as a binary Op.
28+
# This happens with `Add`, which can work as a `Sum` for multiple scalars.
29+
pytorch_variadic_func = getattr(torch, op.nfunc_variadic, None)
30+
if not pytorch_variadic_func:
31+
raise NotImplementedError(
32+
f"Dispatch not implemented for Scalar Op {op} with {len(node.inputs)} inputs"
33+
)
34+
35+
def pytorch_func(*args):
36+
return pytorch_variadic_func(
37+
torch.stack(torch.broadcast_tensors(*args), axis=0), axis=0
38+
)
39+
40+
return pytorch_func

pytensor/link/pytorch/linker.py

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from typing import Any
2+
3+
from pytensor.graph.basic import Variable
4+
from pytensor.link.basic import JITLinker
5+
6+
7+
class PytorchLinker(JITLinker):
8+
"""A `Linker` that compiles NumPy-based operations using torch.compile."""
9+
10+
def input_filter(self, inp: Any) -> Any:
11+
from pytensor.link.pytorch.dispatch import pytorch_typify
12+
13+
return pytorch_typify(inp)
14+
15+
def output_filter(self, var: Variable, out: Any) -> Any:
16+
return out.cpu()
17+
18+
def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
19+
from pytensor.link.pytorch.dispatch import pytorch_funcify
20+
21+
return pytorch_funcify(
22+
fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs
23+
)
24+
25+
def jit_compile(self, fn):
26+
import torch
27+
28+
return torch.compile(fn)
29+
30+
def create_thunk_inputs(self, storage_map):
31+
thunk_inputs = []
32+
for n in self.fgraph.inputs:
33+
sinput = storage_map[n]
34+
thunk_inputs.append(sinput)
35+
36+
return thunk_inputs

tests/link/pytorch/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)