
Commit c715295

Add Numba implementation of Blockwise

1 parent 18ba52c commit c715295

9 files changed: +272 −7 lines

pytensor/link/numba/dispatch/__init__.py

+5 −4

@@ -2,15 +2,16 @@
 from pytensor.link.numba.dispatch.basic import numba_funcify, numba_typify

 # Load dispatch specializations
-import pytensor.link.numba.dispatch.scalar
-import pytensor.link.numba.dispatch.tensor_basic
+import pytensor.link.numba.dispatch.blockwise
+import pytensor.link.numba.dispatch.elemwise
 import pytensor.link.numba.dispatch.extra_ops
 import pytensor.link.numba.dispatch.nlinalg
 import pytensor.link.numba.dispatch.random
-import pytensor.link.numba.dispatch.elemwise
 import pytensor.link.numba.dispatch.scan
-import pytensor.link.numba.dispatch.sparse
+import pytensor.link.numba.dispatch.scalar
 import pytensor.link.numba.dispatch.slinalg
+import pytensor.link.numba.dispatch.sparse
 import pytensor.link.numba.dispatch.subtensor
+import pytensor.link.numba.dispatch.tensor_basic

 # isort: on
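These imports matter only for their side effects: `numba_funcify` is a singledispatch-style registry, so importing `pytensor.link.numba.dispatch.blockwise` is what executes the `@numba_funcify.register` call in the new file below. A minimal sketch of that registration mechanism, using hypothetical names rather than the actual PyTensor internals:

from functools import singledispatch


@singledispatch
def funcify(op, **kwargs):
    # Fallback used when no specialization was registered for type(op)
    raise NotImplementedError(f"No numba implementation for {type(op).__name__}")


class Blockwise:
    """Stand-in for pytensor.tensor.blockwise.Blockwise."""


# Register by annotating the first argument, mirroring @numba_funcify.register below
@funcify.register
def funcify_blockwise(op: Blockwise, **kwargs):
    def implementation(*inputs):
        return inputs  # placeholder body

    return implementation


fn = funcify(Blockwise())  # dispatches on the type of the op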
pytensor/link/numba/dispatch/blockwise.py

+88
@@ -0,0 +1,88 @@
from typing import cast

from numba.core.extending import overload
from numba.np.unsafe.ndarray import to_fixed_tuple

from pytensor.link.numba.dispatch.basic import numba_funcify, numba_njit
from pytensor.link.numba.dispatch.vectorize_codegen import (
    _jit_options,
    _vectorized,
    encode_literals,
    store_core_outputs,
)
from pytensor.link.utils import compile_function_src
from pytensor.tensor import TensorVariable, get_vector_length
from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape


@numba_funcify.register
def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
    [blockwise_node] = op.fgraph.apply_nodes
    blockwise_op: Blockwise = blockwise_node.op
    core_op = blockwise_op.core_op
    nin = len(blockwise_node.inputs)
    nout = len(blockwise_node.outputs)
    core_shapes_len = tuple(get_vector_length(sh) for sh in node.inputs[nin:])

    core_node = blockwise_op._create_dummy_core_node(
        cast(tuple[TensorVariable], blockwise_node.inputs)
    )
    core_op_fn = numba_funcify(
        core_op,
        node=core_node,
        parent_node=node,
        fastmath=_jit_options["fastmath"],
        **kwargs,
    )
    core_op_fn = store_core_outputs(core_op_fn, nin=nin, nout=nout)

    batch_ndim = blockwise_op.batch_ndim(node)

    # numba doesn't support nested literals right now...
    input_bc_patterns = encode_literals(
        tuple(inp.type.broadcastable[:batch_ndim] for inp in node.inputs[:nin])
    )
    output_bc_patterns = encode_literals(
        tuple(out.type.broadcastable[:batch_ndim] for out in node.outputs)
    )
    output_dtypes = encode_literals(tuple(out.type.dtype for out in node.outputs))
    inplace_pattern = encode_literals(())

    # Numba does not allow a tuple generator in the jitted function, so we have to
    # compile a helper to convert core_shapes into tuples.
    # Alternatively, add an Op that converts shape vectors into tuples, like we did for JAX
    src = "def to_tuple(core_shapes): return ("
    for i in range(nout):
        src += f"to_fixed_tuple(core_shapes[{i}], {core_shapes_len[i]}),"
    src += ")"

    to_tuple = numba_njit(
        compile_function_src(
            src,
            "to_tuple",
            global_env={"to_fixed_tuple": to_fixed_tuple},
        )
    )

    def blockwise_wrapper(*inputs_and_core_shapes):
        inputs, core_shapes = inputs_and_core_shapes[:nin], inputs_and_core_shapes[nin:]
        tuple_core_shapes = to_tuple(core_shapes)
        return _vectorized(
            core_op_fn,
            input_bc_patterns,
            output_bc_patterns,
            output_dtypes,
            inplace_pattern,
            (),  # constant_inputs
            inputs,
            tuple_core_shapes,
            None,  # size
        )

    def blockwise(*inputs_and_core_shapes):
        raise NotImplementedError("Non-jitted BlockwiseWithCoreShape not implemented")

    @overload(blockwise, jit_options=_jit_options)
    def ov_blockwise(*inputs_and_core_shapes):
        return blockwise_wrapper

    return blockwise
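For intuition, the string-compiled helper above simply unrolls one `to_fixed_tuple` call per output, because the tuple length passed to `to_fixed_tuple` must be a compile-time constant. Assuming, say, two outputs whose core shapes have lengths 2 and 1, the generated source would look roughly like this (a sketch; the real helper is built from `core_shapes_len` and njit-compiled at funcify time):

from numba.np.unsafe.ndarray import to_fixed_tuple


# Roughly what compile_function_src receives for nout=2, core_shapes_len=(2, 1)
def to_tuple(core_shapes):
    return (
        to_fixed_tuple(core_shapes[0], 2),
        to_fixed_tuple(core_shapes[1], 1),
    )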

pytensor/link/numba/dispatch/random.py

+1 −1

@@ -388,7 +388,7 @@ def random_wrapper(core_shape, rng, size, *dist_params):
         return rng, draws

     def random(core_shape, rng, size, *dist_params):
-        pass
+        raise NotImplementedError("Non-jitted random variable not implemented")

     @overload(random, jit_options=_jit_options)
     def ov_random(core_shape, rng, size, *dist_params):
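For readers unfamiliar with the pattern shared by `random.py` and the new Blockwise dispatch: the plain-Python function is only a stub that Numba swaps out through `@overload` when it compiles the graph, so raising `NotImplementedError` (rather than silently returning `None` via `pass`) makes an accidental non-jitted call fail loudly. A standalone sketch of the pattern with a hypothetical helper, not part of PyTensor:

from numba import njit
from numba.core.extending import overload


def clip_to_zero(x):
    # Pure-Python stub: only ever meant to be called from jitted code
    raise NotImplementedError("Non-jitted clip_to_zero not implemented")


@overload(clip_to_zero)
def ov_clip_to_zero(x):
    # Called by Numba at compile time; the returned function is what gets jitted
    def impl(x):
        return x if x > 0 else 0 * x

    return impl


@njit
def relu(x):
    return clip_to_zero(x)


relu(-3.0)  # returns 0.0; calling clip_to_zero(-3.0) directly would raise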

pytensor/tensor/blockwise.py

+8

@@ -442,3 +442,11 @@ def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply:

 class OpWithCoreShape(OpFromGraph):
     """Generalizes an `Op` to include core shape as an additional input."""
+
+
+class BlockwiseWithCoreShape(OpWithCoreShape):
+    """Generalizes a Blockwise `Op` to include a core shape parameter."""
+
+    def __str__(self):
+        [blockwise_node] = self.fgraph.apply_nodes
+        return f"[{blockwise_node.op!s}]"

pytensor/tensor/random/rewriting/numba.py

+1 −1

@@ -15,7 +15,7 @@ def introduce_explicit_core_shape_rv(fgraph, node):
     This core_shape is used by the numba backend to pre-allocate the output array.

     If available, the core shape is extracted from the shape feature of the graph,
-    which has a higher change of having been simplified, optimized, constant-folded.
+    which has a higher chance of having been simplified, optimized, constant-folded.
     If missing, we fall back to the op._supp_shape_from_params method.

     This rewrite is required for the numba backend implementation of RandomVariable.

pytensor/tensor/rewriting/__init__.py

+1

@@ -9,6 +9,7 @@
 import pytensor.tensor.rewriting.jax
 import pytensor.tensor.rewriting.linalg
 import pytensor.tensor.rewriting.math
+import pytensor.tensor.rewriting.numba
 import pytensor.tensor.rewriting.ofg
 import pytensor.tensor.rewriting.shape
 import pytensor.tensor.rewriting.special

pytensor/tensor/rewriting/numba.py

+108

@@ -0,0 +1,108 @@
from pytensor.compile import optdb
from pytensor.graph import node_rewriter
from pytensor.graph.basic import applys_between
from pytensor.graph.rewriting.basic import out2in
from pytensor.tensor.basic import as_tensor, constant
from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
from pytensor.tensor.rewriting.shape import ShapeFeature


@node_rewriter([Blockwise])
def introduce_explicit_core_shape_blockwise(fgraph, node):
    """Introduce the core shape of a Blockwise.

    We wrap Blockwise graphs into a BlockwiseWithCoreShape OpFromGraph
    that has an extra "non-functional" input that represents the core shape of the Blockwise variable.
    This core_shape is used by the numba backend to pre-allocate the output array.

    If available, the core shape is extracted from the shape feature of the graph,
    which has a higher chance of having been simplified, optimized, constant-folded.
    If missing, we fall back to the op._supp_shape_from_params method.

    This rewrite is required for the numba backend implementation of Blockwise.

    Example
    -------

    .. code-block:: python

        import pytensor
        import pytensor.tensor as pt

        x = pt.tensor("x", shape=(5, None, None))
        outs = pt.linalg.svd(x, compute_uv=True)
        pytensor.dprint(outs)
        # Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}.0 [id A]
        #  └─ x [id B]
        # Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}.1 [id A]
        #  └─ ···
        # Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}.2 [id A]
        #  └─ ···

        # After the rewrite, note the new 3 core shape inputs
        fn = pytensor.function([x], outs, mode="NUMBA")
        fn.dprint(print_type=False)
        # [Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}].0 [id A] 6
        #  ├─ x [id B]
        #  ├─ MakeVector{dtype='int64'} [id C] 5
        #  │  ├─ Shape_i{1} [id D] 2
        #  │  │  └─ x [id B]
        #  │  └─ Shape_i{1} [id D] 2
        #  │     └─ ···
        #  ├─ MakeVector{dtype='int64'} [id E] 4
        #  │  └─ Minimum [id F] 3
        #  │     ├─ Shape_i{1} [id D] 2
        #  │     │  └─ ···
        #  │     └─ Shape_i{2} [id G] 0
        #  │        └─ x [id B]
        #  └─ MakeVector{dtype='int64'} [id H] 1
        #     ├─ Shape_i{2} [id G] 0
        #     │  └─ ···
        #     └─ Shape_i{2} [id G] 0
        #        └─ ···
        # [Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}].1 [id A] 6
        #  └─ ···
        # [Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}].2 [id A] 6
        #  └─ ···
    """
    op: Blockwise = node.op  # type: ignore[annotation-unchecked]
    batch_ndim = op.batch_ndim(node)

    shape_feature: ShapeFeature | None = getattr(fgraph, "shape_feature", None)  # type: ignore[annotation-unchecked]
    if shape_feature:
        core_shapes = [
            [shape_feature.get_shape(out, i) for i in range(batch_ndim, out.type.ndim)]
            for out in node.outputs
        ]
    else:
        input_shapes = [tuple(inp.shape) for inp in node.inputs]
        core_shapes = [
            out_shape[batch_ndim:]
            for out_shape in op.infer_shape(None, node, input_shapes)
        ]

    core_shapes = [
        as_tensor(core_shape) if len(core_shape) else constant([], dtype="int64")
        for core_shape in core_shapes
    ]

    if any(
        isinstance(node.op, Blockwise)
        for node in applys_between(node.inputs, core_shapes)
    ):
        # If Blockwise shows up in the shape graph we can't introduce the core shape
        return None

    return BlockwiseWithCoreShape(
        [*node.inputs, *core_shapes],
        node.outputs,
        destroy_map=op.destroy_map,
    )(*node.inputs, *core_shapes, return_list=True)


optdb.register(
    introduce_explicit_core_shape_blockwise.__name__,
    out2in(introduce_explicit_core_shape_blockwise),
    "numba",
    position=100,
)

tests/link/numba/test_basic.py

+1 −1

@@ -244,7 +244,7 @@ def compare_numba_and_py(
     Parameters
     ----------
     fgraph
-        `FunctionGraph` or inputs to compare.
+        `FunctionGraph` or tuple(inputs, outputs) to compare.
     inputs
         Numeric inputs to be passed to the compiled graphs.
     assert_fn

tests/link/numba/test_blockwise.py

+59

@@ -0,0 +1,59 @@
import numpy as np
import pytest

from pytensor import function
from pytensor.tensor import tensor
from pytensor.tensor.basic import ARange
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import SVD, Det
from pytensor.tensor.slinalg import Cholesky, cholesky
from tests.link.numba.test_basic import compare_numba_and_py, numba_mode


# Fails if object mode warning is issued when not expected
pytestmark = pytest.mark.filterwarnings("error")


@pytest.mark.parametrize("shape_opt", [True, False], ids=str)
@pytest.mark.parametrize("core_op", [Det(), Cholesky(), SVD(compute_uv=True)], ids=str)
def test_blockwise(core_op, shape_opt):
    x = tensor(shape=(5, None, None))
    outs = Blockwise(core_op=core_op)(x, return_list=True)

    mode = (
        numba_mode.including("ShapeOpt")
        if shape_opt
        else numba_mode.excluding("ShapeOpt")
    )
    x_test = np.eye(3) * np.arange(1, 6)[:, None, None]
    compare_numba_and_py(
        ([x], outs),
        [x_test],
        numba_mode=mode,
        eval_obj_mode=False,
    )


def test_non_square_blockwise():
    """Test that Op that cannot always be blockwised at runtime fails gracefully."""
    x = tensor(shape=(3,), dtype="int64")
    out = Blockwise(core_op=ARange(dtype="int64"), signature="(),(),()->(a)")(0, x, 1)

    with pytest.warns(UserWarning, match="Numba will use object mode"):
        fn = function([x], out, mode="NUMBA")

    np.testing.assert_allclose(fn([5, 5, 5]), np.broadcast_to(np.arange(5), (3, 5)))

    with pytest.raises(ValueError):
        fn([3, 4, 5])


def test_blockwise_benchmark(benchmark):
    x = tensor(shape=(5, 3, 3))
    out = cholesky(x)
    assert isinstance(out.owner.op, Blockwise)

    fn = function([x], out, mode="NUMBA")
    x_test = np.eye(3) * np.arange(1, 6)[:, None, None]
    fn(x_test)  # JIT compile
    benchmark(fn, x_test)
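Outside the test suite, the same machinery is exercised by any batched linear-algebra graph compiled with the Numba backend. A minimal end-user sketch mirroring the benchmark above (assuming Numba is installed so the NUMBA mode is available):

import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import cholesky

x = pt.tensor("x", shape=(5, 3, 3))
out = cholesky(x)  # a (3, 3) Cholesky core, blockwised over the leading batch of 5

# Compiling with mode="NUMBA" runs the core-shape rewrite and then the
# numba_funcify_Blockwise dispatch added in this commit
fn = pytensor.function([x], out, mode="NUMBA")

x_test = np.eye(3) * np.arange(1, 6)[:, None, None]
print(fn(x_test).shape)  # (5, 3, 3)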
