From a77c9eebb10918aabc33e4e7e613382057f6277e Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 7 Jun 2024 18:13:46 +0200 Subject: [PATCH 01/15] Add __repr__ for Numba and JAX linkers --- pytensor/link/jax/linker.py | 3 +++ pytensor/link/numba/linker.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/pytensor/link/jax/linker.py b/pytensor/link/jax/linker.py index 06370b4514..5e1a87e31e 100644 --- a/pytensor/link/jax/linker.py +++ b/pytensor/link/jax/linker.py @@ -87,3 +87,6 @@ def create_thunk_inputs(self, storage_map): thunk_inputs.append(sinput) return thunk_inputs + + def __repr__(self): + return "JAXLinker()" diff --git a/pytensor/link/numba/linker.py b/pytensor/link/numba/linker.py index 553c5ef217..3cf542a316 100644 --- a/pytensor/link/numba/linker.py +++ b/pytensor/link/numba/linker.py @@ -35,3 +35,6 @@ def create_thunk_inputs(self, storage_map): thunk_inputs.append(sinput) return thunk_inputs + + def __repr__(self): + return "NumbaLinker()" From 81988be653f26ba49ccafbbc43d42ce5b96fbd7d Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 7 Jun 2024 19:04:51 +0200 Subject: [PATCH 02/15] Simplify Numba implementation of Alloc --- pytensor/link/numba/dispatch/tensor_basic.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytensor/link/numba/dispatch/tensor_basic.py b/pytensor/link/numba/dispatch/tensor_basic.py index 8f5972c058..30e89c4256 100644 --- a/pytensor/link/numba/dispatch/tensor_basic.py +++ b/pytensor/link/numba/dispatch/tensor_basic.py @@ -68,7 +68,7 @@ def numba_funcify_Alloc(op, node, **kwargs): shape_var_item_names = [f"{name}_item" for name in shape_var_names] shapes_to_items_src = indent( "\n".join( - f"{item_name} = to_scalar({shape_name})" + f"{item_name} = {shape_name}.item()" for item_name, shape_name in zip( shape_var_item_names, shape_var_names, strict=True ) @@ -86,7 +86,7 @@ def numba_funcify_Alloc(op, node, **kwargs): alloc_def_src = f""" def alloc(val, {", ".join(shape_var_names)}): - val_np = np.asarray(val) + val_np = val {shapes_to_items_src} scalar_shape = {create_tuple_string(shape_var_item_names)} {check_runtime_broadcast_src} From c62d733cb9e4e1aecbc885897363e9479ee72752 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 7 Jun 2024 19:05:19 +0200 Subject: [PATCH 03/15] Add error message in Numba implementation of SpecifyShape --- pytensor/link/numba/dispatch/basic.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 2f3cac6ea6..55cebc2d60 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -481,11 +481,11 @@ def numba_funcify_SpecifyShape(op, node, **kwargs): shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))] func_conditions = [ - f"assert x.shape[{i}] == {shape_input_names}" - for i, (shape_input, shape_input_names) in enumerate( + f"assert x.shape[{i}] == {eval_dim_name}, f'SpecifyShape: dim {{{i}}} of input has shape {{x.shape[{i}]}}, expected {{{eval_dim_name}.item()}}.'" + for i, (node_dim_input, eval_dim_name) in enumerate( zip(shape_inputs, shape_input_names, strict=True) ) - if shape_input is not NoneConst + if node_dim_input is not NoneConst ] func = dedent( From dae7ff3d63b1d8d68f47e78af3f72d16d020a419 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 13 Jun 2024 16:25:42 +0200 Subject: [PATCH 04/15] Add support for TypedList in numba backend --- pytensor/link/numba/dispatch/basic.py | 3 +++ 1 file changed, 3 
insertions(+) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 55cebc2d60..f961524fa0 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -40,6 +40,7 @@ from pytensor.tensor.slinalg import Solve from pytensor.tensor.type import TensorType from pytensor.tensor.type_other import MakeSlice, NoneConst +from pytensor.typed_list import TypedListType def global_numba_func(func): @@ -135,6 +136,8 @@ def get_numba_type( return CSCMatrixType(numba_dtype) raise NotImplementedError() + elif isinstance(pytensor_type, TypedListType): + return numba.types.List(get_numba_type(pytensor_type.ttype)) else: raise NotImplementedError(f"Numba type not implemented for {pytensor_type}") From 8fe34a174c569014d13081f0e3337097731d1719 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 7 Jun 2024 18:06:58 +0200 Subject: [PATCH 05/15] Try to run full test suite in Numba backend --- .github/workflows/test.yml | 2 +- pytensor/compile/mode.py | 33 +++++++++++++++++++++------------ pytensor/configdefaults.py | 16 +++++++++++++--- 3 files changed, 35 insertions(+), 16 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9677615206..48ada81faf 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -193,7 +193,7 @@ jobs: else micromamba install --yes -q "python~=${PYTHON_VERSION}" mkl "numpy${NUMPY_VERSION}" scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock; fi - if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi + micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57" if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro && pip install tensorflow-probability; fi if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi pip install pytest-sphinx diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py index ffa27e5d5a..adf9732f5e 100644 --- a/pytensor/compile/mode.py +++ b/pytensor/compile/mode.py @@ -63,9 +63,8 @@ def register_linker(name, linker): # If a string is passed as the optimizer argument in the constructor # for Mode, it will be used as the key to retrieve the real optimizer # in this dictionary -exclude = [] -if not config.cxx: - exclude = ["cxx_only"] + +exclude = ["cxx_only", "BlasOpt"] OPT_NONE = RewriteDatabaseQuery(include=[], exclude=exclude) # Even if multiple merge optimizer call will be there, this shouldn't # impact performance. @@ -346,6 +345,11 @@ def __setstate__(self, state): optimizer = predefined_optimizers[optimizer] if isinstance(optimizer, RewriteDatabaseQuery): self.provided_optimizer = optimizer + + # Force numba-required rewrites if using NumbaLinker + if isinstance(linker, NumbaLinker): + optimizer = optimizer.including("numba") + self._optimizer = optimizer self.call_time = 0 self.fn_time = 0 @@ -443,16 +447,20 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): # string as the key # Use VM_linker to allow lazy evaluation by default. FAST_COMPILE = Mode( - VMLinker(use_cloop=False, c_thunks=False), - RewriteDatabaseQuery(include=["fast_compile", "py_only"]), + NumbaLinker(), + # TODO: Fast_compile should just use python code, CHANGE ME!
+ RewriteDatabaseQuery( + include=["fast_compile", "numba"], + exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"], + ), +) +FAST_RUN = Mode( + NumbaLinker(), + RewriteDatabaseQuery( + include=["fast_run", "numba"], + exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"], + ), ) -if config.cxx: - FAST_RUN = Mode("cvm", "fast_run") -else: - FAST_RUN = Mode( - "vm", - RewriteDatabaseQuery(include=["fast_run", "py_only"]), - ) NUMBA = Mode( NumbaLinker(), @@ -565,6 +573,7 @@ def register_mode(name, mode): Add a `Mode` which can be referred to by `name` in `function`. """ + # TODO: Remove me if name in predefined_modes: raise ValueError(f"Mode name already taken: {name}") predefined_modes[name] = mode diff --git a/pytensor/configdefaults.py b/pytensor/configdefaults.py index ca3c44bf6d..9c8fecab33 100644 --- a/pytensor/configdefaults.py +++ b/pytensor/configdefaults.py @@ -370,11 +370,21 @@ def add_compile_configvars(): if rc == 0 and config.cxx != "": # Keep the default linker the same as the one for the mode FAST_RUN - linker_options = ["c|py", "py", "c", "c|py_nogc", "vm", "vm_nogc", "cvm_nogc"] + linker_options = [ + "cvm", + "c|py", + "py", + "c", + "c|py_nogc", + "vm", + "vm_nogc", + "cvm_nogc", + "jax", + ] else: # g++ is not present or the user disabled it, # linker should default to python only. - linker_options = ["py", "vm_nogc"] + linker_options = ["py", "vm", "vm_nogc", "jax"] if type(config).cxx.is_default: # If the user provided an empty value for cxx, do not warn. _logger.warning( @@ -388,7 +398,7 @@ def add_compile_configvars(): "linker", "Default linker used if the pytensor flags mode is Mode", # Not mutable because the default mode is cached after the first use. - EnumStr("cvm", linker_options, mutable=False), + EnumStr("numba", linker_options, mutable=False), in_c_key=False, ) From 1e88837b28960e44132be2d5187e707ca28de6f0 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 7 Jun 2024 19:05:35 +0200 Subject: [PATCH 06/15] Change error message to that raised from Numba --- tests/tensor/rewriting/test_basic.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py index 1730ae46ac..a9a1cecfb7 100644 --- a/tests/tensor/rewriting/test_basic.py +++ b/tests/tensor/rewriting/test_basic.py @@ -1,4 +1,5 @@ import copy +import re import numpy as np import pytest @@ -306,7 +307,9 @@ def test_inconsistent_shared(self, shape_unsafe): # Error raised by Alloc Op with pytest.raises( ValueError, - match=r"could not broadcast input array from shape \(3,7\) into shape \(6,7\)", + match=re.escape( + "cannot assign slice of shape (3, 7) from input of shape (6, 7)" + ), ): f() From 5bacff2ba2640f5131548423ff84073129901197 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 7 Jun 2024 19:21:25 +0200 Subject: [PATCH 07/15] XFAIL float16 test --- tests/tensor/rewriting/test_basic.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py index a9a1cecfb7..6f04b61506 100644 --- a/tests/tensor/rewriting/test_basic.py +++ b/tests/tensor/rewriting/test_basic.py @@ -1206,6 +1206,7 @@ def test_sum_bool_upcast(self): f(5) +@pytest.mark.xfail(reason="Numba does not support float16") class TestLocalOptAllocF16(TestLocalOptAlloc): dtype = "float16" From 4fcb060c85144c2eb6ccf8bf6932f5828ad2b891 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 13 Jun 2024 16:05:43 +0200 Subject: [PATCH 08/15] XFAIL conv tests of Ops without 
Python implementation Mark overly specific tests as xfail --- tests/tensor/conv/test_abstract_conv.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/tensor/conv/test_abstract_conv.py b/tests/tensor/conv/test_abstract_conv.py index 23ba23e1e9..814f1eb80b 100644 --- a/tests/tensor/conv/test_abstract_conv.py +++ b/tests/tensor/conv/test_abstract_conv.py @@ -948,9 +948,9 @@ def run_gradinput( ) -@pytest.mark.skipif( - config.cxx == "", - reason="SciPy and cxx needed", +@pytest.mark.skipif(config.cxx == "", reason="cxx needed") +@pytest.mark.xfail( + reason="Involves Ops with no Python implementation for numba to use as fallback" ) class TestAbstractConvNoOptim(BaseTestConv2d): @classmethod @@ -1884,9 +1884,9 @@ def test_conv2d_grad_wrt_weights(self): ) -@pytest.mark.skipif( - config.cxx == "", - reason="SciPy and cxx needed", +@pytest.mark.skipif(config.cxx == "", reason="cxx needed") +@pytest.mark.xfail( + reason="Involves Ops with no Python implementation for numba to use as fallback" ) class TestGroupedConvNoOptim: conv = abstract_conv.AbstractConv2d @@ -2096,9 +2096,9 @@ def conv_gradinputs(filters_val, output_val): utt.verify_grad(conv_gradinputs, [kern, top], mode=self.mode, eps=1) -@pytest.mark.skipif( - config.cxx == "", - reason="SciPy and cxx needed", +@pytest.mark.skipif(config.cxx == "", reason="cxx needed") +@pytest.mark.xfail( + reason="Involves Ops with no Python implementation for numba to use as fallback" ) class TestGroupedConv3dNoOptim(TestGroupedConvNoOptim): conv = abstract_conv.AbstractConv3d From e8e103d0d9cc7c526131f407974b4b7eb4cfdabd Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 7 Jun 2024 20:01:26 +0200 Subject: [PATCH 09/15] Refactor test and change expected counts of Alloc that were due to BlasOpt --- tests/tensor/test_basic.py | 56 ++++++++++++++++++++------------------ 1 file changed, 29 insertions(+), 27 deletions(-) diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index dee0023efd..97a2631e01 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -758,41 +758,43 @@ def check_allocs_in_fgraph(fgraph, n): def setup_method(self): self.rng = np.random.default_rng(seed=utt.fetch_seed()) - def test_alloc_constant_folding(self): + @pytest.mark.parametrize( + "subtensor_fn, expected_grad_n_alloc", + [ + # IncSubtensor1 + (lambda x: x[:60], 1), + # AdvancedIncSubtensor1 + (lambda x: x[np.arange(60)], 1), + # AdvancedIncSubtensor + (lambda x: x[np.arange(50), np.arange(50)], 1), + ], + ) + def test_alloc_constant_folding(self, subtensor_fn, expected_grad_n_alloc): test_params = np.asarray(self.rng.standard_normal(50 * 60), self.dtype) some_vector = vector("some_vector", dtype=self.dtype) some_matrix = some_vector.reshape((60, 50)) variables = self.shared(np.ones((50,), dtype=self.dtype)) - idx = constant(np.arange(50)) - - for alloc_, (subtensor, n_alloc) in zip( - self.allocs, - [ - # IncSubtensor1 - (some_matrix[:60], 2), - # AdvancedIncSubtensor1 - (some_matrix[arange(60)], 2), - # AdvancedIncSubtensor - (some_matrix[idx, idx], 1), - ], - strict=True, - ): - derp = pt_sum(dense_dot(subtensor, variables)) - fobj = pytensor.function([some_vector], derp, mode=self.mode) - grad_derp = pytensor.grad(derp, some_vector) - fgrad = pytensor.function([some_vector], grad_derp, mode=self.mode) + subtensor = subtensor_fn(some_matrix) - topo_obj = fobj.maker.fgraph.toposort() - assert sum(isinstance(node.op, type(alloc_)) for node in topo_obj) == 0 + derp = pt_sum(dense_dot(subtensor, 
variables)) + fobj = pytensor.function([some_vector], derp, mode=self.mode) + assert ( + sum(isinstance(node.op, Alloc) for node in fobj.maker.fgraph.apply_nodes) + == 0 + ) + # TODO: Assert something about the value if we bothered to call it? + fobj(test_params) - topo_grad = fgrad.maker.fgraph.toposort() - assert ( - sum(isinstance(node.op, type(alloc_)) for node in topo_grad) == n_alloc - ), (alloc_, subtensor, n_alloc, topo_grad) - fobj(test_params) - fgrad(test_params) + grad_derp = pytensor.grad(derp, some_vector) + fgrad = pytensor.function([some_vector], grad_derp, mode=self.mode) + assert ( + sum(isinstance(node.op, Alloc) for node in fgrad.maker.fgraph.apply_nodes) + == expected_grad_n_alloc + ) + # TODO: Assert something about the value if we bothered to call it? + fgrad(test_params) def test_alloc_output(self): val = constant(self.rng.standard_normal((1, 1)), dtype=self.dtype) From 99de2b05f3ea0cea4ec3801cdaf5445aad3aedd9 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 13 Jun 2024 15:09:28 +0200 Subject: [PATCH 10/15] Temporary patch for https://github.com/numba/numba/issues/9554 --- pytensor/link/numba/dispatch/scalar.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py index e9b637b00f..31904eac04 100644 --- a/pytensor/link/numba/dispatch/scalar.py +++ b/pytensor/link/numba/dispatch/scalar.py @@ -23,6 +23,7 @@ Composite, Identity, Mul, + Pow, Reciprocal, ScalarOp, Second, @@ -154,6 +155,21 @@ def numba_funcify_Switch(op, node, **kwargs): return numba_basic.global_numba_func(switch) +@numba_funcify.register(Pow) +def numba_funcify_Pow(op, node, **kwargs): + pow_dtype = node.inputs[1].type.dtype + + def pow(x, y): + return x**y + + # Work-around https://github.com/numba/numba/issues/9554 + # fast-math causes kernel crashes + patch_kwargs = {} + if pow_dtype.startswith("int"): + patch_kwargs["fastmath"] = False + return numba_basic.numba_njit(**patch_kwargs)(pow) + + def binary_to_nary_func(inputs: list[Variable], binary_op_name: str, binary_op: str): """Create a Numba-compatible N-ary function from a binary function.""" unique_names = unique_name_generator(["binary_op_name"], suffix_sep="_") From 17ef6a4c401a9e5f090848539f0bd5377037fe24 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 13 Jun 2024 16:10:24 +0200 Subject: [PATCH 11/15] Allow tests to pass on object mode --- tests/tensor/test_basic.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 97a2631e01..dfe9f2630f 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -150,7 +150,10 @@ ) -pytestmark = pytest.mark.filterwarnings("error") +pytestmark = pytest.mark.filterwarnings( + "error", + "ignore:Numba will use object mode:UserWarning", +) if config.mode == "FAST_COMPILE": mode_opt = "FAST_RUN" From 59078746c11d15f7d104dc3784cc51353890050b Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 28 Mar 2025 15:32:53 +0100 Subject: [PATCH 12/15] .hacks --- pytensor/compile/mode.py | 1 + pytensor/link/numba/dispatch/basic.py | 22 +++++++++--------- pytensor/link/numba/dispatch/elemwise.py | 8 +++++-- pytensor/link/numba/dispatch/scalar.py | 8 +++++++ pytensor/link/numba/dispatch/subtensor.py | 23 +++++++++++++++++++ .../link/numba/dispatch/vectorize_codegen.py | 10 ++++++++ 6 files changed, 59 insertions(+), 13 deletions(-) diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py index
adf9732f5e..73f5c66efa 100644 --- a/pytensor/compile/mode.py +++ b/pytensor/compile/mode.py @@ -507,6 +507,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): predefined_modes = { "FAST_COMPILE": FAST_COMPILE, "FAST_RUN": FAST_RUN, + "OLD_FAST_RUN": Mode("cvm", "fast_run"), "JAX": JAX, "NUMBA": NUMBA, "PYTORCH": PYTORCH, diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index f961524fa0..ed0f5670f4 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -12,7 +12,7 @@ import scipy.special from llvmlite import ir from numba import types -from numba.core.errors import NumbaWarning, TypingError +from numba.core.errors import TypingError from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401 from numba.extending import box, overload @@ -71,16 +71,16 @@ def numba_njit(*args, fastmath=None, **kwargs): # Suppress cache warning for internal functions # We have to add an ansi escape code for optional bold text by numba - warnings.filterwarnings( - "ignore", - message=( - "(\x1b\\[1m)*" # ansi escape code for bold text - "Cannot cache compiled function " - '"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve)" ' - "as it uses dynamic globals" - ), - category=NumbaWarning, - ) + # warnings.filterwarnings( + # "ignore", + # message=( + # "(\x1b\\[1m)*" # ansi escape code for bold text + # "Cannot cache compiled function " + # '"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve)" ' + # "as it uses dynamic globals" + # ), + # category=NumbaWarning, + # ) if len(args) > 0 and callable(args[0]): return numba.njit(*args[1:], fastmath=fastmath, **kwargs)(args[0]) diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 9fd81dadcf..bcd011fb67 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -16,7 +16,6 @@ _jit_options, _vectorized, encode_literals, - store_core_outputs, ) from pytensor.link.utils import compile_function_src from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple @@ -276,7 +275,12 @@ def numba_funcify_Elemwise(op, node, **kwargs): nin = len(node.inputs) nout = len(node.outputs) - core_op_fn = store_core_outputs(scalar_op_fn, nin=nin, nout=nout) + # core_op_fn = store_core_outputs(scalar_op_fn, nin=nin, nout=nout) + if isinstance(op.scalar_op, Mul) and len(node.inputs) == 2: + + @numba_njit + def core_op_fn(x, y, out): + out[...] 
= x * y input_bc_patterns = tuple(inp.type.broadcastable for inp in node.inputs) output_bc_patterns = tuple(out.type.broadcastable for out in node.outputs) diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py index 31904eac04..018a2bd5a0 100644 --- a/pytensor/link/numba/dispatch/scalar.py +++ b/pytensor/link/numba/dispatch/scalar.py @@ -196,6 +196,14 @@ def numba_funcify_Add(op, node, **kwargs): @numba_funcify.register(Mul) def numba_funcify_Mul(op, node, **kwargs): + if len(node.inputs) == 2: + + @numba_basic.numba_njit + def binary_mul(x, y): + return x * y + + return binary_mul + signature = create_numba_signature(node, force_scalar=True) nary_add_fn = binary_to_nary_func(node.inputs, "mul", "*") diff --git a/pytensor/link/numba/dispatch/subtensor.py b/pytensor/link/numba/dispatch/subtensor.py index ee9e183d16..a5db84106e 100644 --- a/pytensor/link/numba/dispatch/subtensor.py +++ b/pytensor/link/numba/dispatch/subtensor.py @@ -13,6 +13,7 @@ AdvancedSubtensor1, IncSubtensor, Subtensor, + get_idx_list, ) from pytensor.tensor.type_other import NoneTypeT, SliceType @@ -95,6 +96,9 @@ def {function_name}({", ".join(input_names)}): return np.asarray(z) """ + print() + node.dprint(depth=2, print_type=True) + print("subtensor_def_src:", subtensor_def_src) func = compile_function_src( subtensor_def_src, function_name=function_name, @@ -103,6 +107,25 @@ def {function_name}({", ".join(input_names)}): return numba_njit(func, boundscheck=True) +@numba_funcify.register(Subtensor) +def numba_funcify_subtensor_custom(op, node, **kwargs): + idxs = get_idx_list(node.inputs, op.idx_list) + + if ( + idxs + and not isinstance(idxs[0], slice) + and all(idx == slice(None) for idx in idxs[1:]) + ): + + @numba_njit + def scalar_subtensor_leading_dim(x, idx): + return x[idx] + + return scalar_subtensor_leading_dim + + return numba_funcify_default_subtensor(op, node, **kwargs) + + @numba_funcify.register(AdvancedSubtensor) @numba_funcify.register(AdvancedIncSubtensor) def numba_funcify_AdvancedSubtensor(op, node, **kwargs): diff --git a/pytensor/link/numba/dispatch/vectorize_codegen.py b/pytensor/link/numba/dispatch/vectorize_codegen.py index 74870e29bd..7052612e6e 100644 --- a/pytensor/link/numba/dispatch/vectorize_codegen.py +++ b/pytensor/link/numba/dispatch/vectorize_codegen.py @@ -35,6 +35,16 @@ def store_core_outputs(i0, i1, ..., in, o0, o1, ..., on): on[...] = ton """ + if nin == 2 and nout == 1: + + @numba_basic.numba_njit + def store_core_outputs_2in1out(i0, i1, o0): + t0 = core_op_fn(i0, i1) + o0[...] 
= t0 + + return store_core_outputs_2in1out + print(nin, nout) + inputs = [f"i{i}" for i in range(nin)] outputs = [f"o{i}" for i in range(nout)] inner_outputs = [f"t{output}" for output in outputs] From ed84a8a7ebfe189b2a3aa6336ff4ecb7c73a6db0 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 28 Mar 2025 15:42:59 +0100 Subject: [PATCH 13/15] .new_test --- tests/tensor/rewriting/test_basic.py | 115 +++++++++++++++------------ 1 file changed, 62 insertions(+), 53 deletions(-) diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py index 6f04b61506..93635aeb4a 100644 --- a/tests/tensor/rewriting/test_basic.py +++ b/tests/tensor/rewriting/test_basic.py @@ -1,5 +1,4 @@ import copy -import re import numpy as np import pytest @@ -143,55 +142,68 @@ def rewrite(g, level="fast_run"): return g -def test_local_useless_slice(): - # test a simple matrix - x = matrix("x") - mode_excluding = get_default_mode().excluding( - "local_useless_slice", "local_mul_canonizer" - ) - mode_including = ( - get_default_mode() - .including("local_useless_slice") - .excluding("local_mul_canonizer") - ) +class TestME: + def local_useless_slice_tester(self): + # test a simple matrix + x = matrix("x") + mode_excluding = get_mode("NUMBA").excluding( + "local_useless_slice", "local_mul_canonizer" + ) + mode_including = ( + get_mode("NUMBA") + .including("local_useless_slice") + .excluding("local_mul_canonizer") + ) - # test with and without the useless slice - o = 2 * x[0, :] - f_excluding = function([x], o, mode=mode_excluding) - f_including = function([x], o, mode=mode_including) - rng = np.random.default_rng(utt.fetch_seed()) - test_inp = rng.integers(-10, 10, (4, 4)).astype("float32") - assert all(f_including(test_inp) == f_excluding(test_inp)) - # test to see if the slice is truly gone - apply_node = f_including.maker.fgraph.toposort()[0] - subtens = apply_node.op - assert not any(isinstance(idx, slice) for idx in subtens.idx_list) - - # Now test that the stack trace is copied over properly, - # before before and after rewriting. - assert check_stack_trace(f_excluding, ops_to_check="all") - assert check_stack_trace(f_including, ops_to_check="all") - - # test a 4d tensor - z = tensor4("z") - o2 = z[1, :, :, 1] - o3 = z[0, :, :, :] - f_including_check = function([z], o2, mode=mode_including) - f_including_check_apply = function([z], o3, mode=mode_including) - - # The rewrite shouldn't apply here - apply_node = f_including_check.maker.fgraph.toposort()[0] - subtens = apply_node.op - assert [isinstance(idx, slice) for idx in subtens.idx_list].count(True) == 2 - # But it should here - apply_node = f_including_check_apply.maker.fgraph.toposort()[0] - subtens = apply_node.op - assert not any(isinstance(idx, slice) for idx in subtens.idx_list) - - # Finally, test that the stack trace is copied over properly, - # before before and after rewriting. 
- assert check_stack_trace(f_including_check, ops_to_check=Subtensor) - assert check_stack_trace(f_including_check_apply, ops_to_check=Subtensor) + # test with and without the useless slice + o = 2 * x[0, :] + f_excluding = function([x], o, mode=mode_excluding) + f_including = function([x], o, mode=mode_including) + rng = np.random.default_rng(utt.fetch_seed()) + test_inp = rng.integers(-10, 10, (4, 4)).astype("float32") + assert all(f_including(test_inp) == f_excluding(test_inp)) + # test to see if the slice is truly gone + apply_node = f_including.maker.fgraph.toposort()[0] + subtens = apply_node.op + assert not any(isinstance(idx, slice) for idx in subtens.idx_list) + + # Now test that the stack trace is copied over properly, + # before before and after rewriting. + assert check_stack_trace(f_excluding, ops_to_check="all") + assert check_stack_trace(f_including, ops_to_check="all") + + # test a 4d tensor + z = tensor4("z") + o2 = z[1, :, :, 1] + o3 = z[0, :, :, :] + f_including_check = function([z], o2, mode=mode_including) + f_including_check_apply = function([z], o3, mode=mode_including) + + # The rewrite shouldn't apply here + apply_node = f_including_check.maker.fgraph.toposort()[0] + subtens = apply_node.op + assert [isinstance(idx, slice) for idx in subtens.idx_list].count(True) == 2 + # But it should here + apply_node = f_including_check_apply.maker.fgraph.toposort()[0] + subtens = apply_node.op + assert not any(isinstance(idx, slice) for idx in subtens.idx_list) + + # Finally, test that the stack trace is copied over properly, + # before before and after rewriting. + assert check_stack_trace(f_including_check, ops_to_check=Subtensor) + assert check_stack_trace(f_including_check_apply, ops_to_check=Subtensor) + + def test_t0(self): + import pytensor + + x = pt.vector("x") + pytensor.function([x], x + 1) + + def test_t1(self): + self.local_useless_slice_tester() + + def test_t2(self): + self.local_useless_slice_tester() def test_local_useless_fill(): @@ -307,9 +319,7 @@ def test_inconsistent_shared(self, shape_unsafe): # Error raised by Alloc Op with pytest.raises( ValueError, - match=re.escape( - "cannot assign slice of shape (3, 7) from input of shape (6, 7)" - ), + match=r"could not broadcast input array from shape \(3,7\) into shape \(6,7\)", ): f() @@ -1206,7 +1216,6 @@ def test_sum_bool_upcast(self): f(5) -@pytest.mark.xfail(reason="Numba does not support float16") class TestLocalOptAllocF16(TestLocalOptAlloc): dtype = "float16" From 32311ba6ca26fb5511bf1ed13b3f91d31182f63b Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 28 Mar 2025 15:45:03 +0100 Subject: [PATCH 14/15] Revert ".hacks" This reverts commit 59078746c11d15f7d104dc3784cc51353890050b. 
--- pytensor/compile/mode.py | 1 - pytensor/link/numba/dispatch/basic.py | 22 +++++++++--------- pytensor/link/numba/dispatch/elemwise.py | 8 ++----- pytensor/link/numba/dispatch/scalar.py | 8 ------- pytensor/link/numba/dispatch/subtensor.py | 23 ------------------- .../link/numba/dispatch/vectorize_codegen.py | 10 -------- 6 files changed, 13 insertions(+), 59 deletions(-) diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py index 73f5c66efa..adf9732f5e 100644 --- a/pytensor/compile/mode.py +++ b/pytensor/compile/mode.py @@ -507,7 +507,6 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): predefined_modes = { "FAST_COMPILE": FAST_COMPILE, "FAST_RUN": FAST_RUN, - "OLD_FAST_RUN": Mode("cvm", "fast_run"), "JAX": JAX, "NUMBA": NUMBA, "PYTORCH": PYTORCH, diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index ed0f5670f4..f961524fa0 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -12,7 +12,7 @@ import scipy.special from llvmlite import ir from numba import types -from numba.core.errors import TypingError +from numba.core.errors import NumbaWarning, TypingError from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401 from numba.extending import box, overload @@ -71,16 +71,16 @@ def numba_njit(*args, fastmath=None, **kwargs): # Suppress cache warning for internal functions # We have to add an ansi escape code for optional bold text by numba - # warnings.filterwarnings( - # "ignore", - # message=( - # "(\x1b\\[1m)*" # ansi escape code for bold text - # "Cannot cache compiled function " - # '"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve)" ' - # "as it uses dynamic globals" - # ), - # category=NumbaWarning, - # ) + warnings.filterwarnings( + "ignore", + message=( + "(\x1b\\[1m)*" # ansi escape code for bold text + "Cannot cache compiled function " + '"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve)" ' + "as it uses dynamic globals" + ), + category=NumbaWarning, + ) if len(args) > 0 and callable(args[0]): return numba.njit(*args[1:], fastmath=fastmath, **kwargs)(args[0]) diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index bcd011fb67..9fd81dadcf 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -16,6 +16,7 @@ _jit_options, _vectorized, encode_literals, + store_core_outputs, ) from pytensor.link.utils import compile_function_src from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple @@ -275,12 +276,7 @@ def numba_funcify_Elemwise(op, node, **kwargs): nin = len(node.inputs) nout = len(node.outputs) - # core_op_fn = store_core_outputs(scalar_op_fn, nin=nin, nout=nout) - if isinstance(op.scalar_op, Mul) and len(node.inputs) == 2: - - @numba_njit - def core_op_fn(x, y, out): - out[...] 
= x * y + core_op_fn = store_core_outputs(scalar_op_fn, nin=nin, nout=nout) input_bc_patterns = tuple(inp.type.broadcastable for inp in node.inputs) output_bc_patterns = tuple(out.type.broadcastable for out in node.outputs) diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py index 018a2bd5a0..31904eac04 100644 --- a/pytensor/link/numba/dispatch/scalar.py +++ b/pytensor/link/numba/dispatch/scalar.py @@ -196,14 +196,6 @@ def numba_funcify_Add(op, node, **kwargs): @numba_funcify.register(Mul) def numba_funcify_Mul(op, node, **kwargs): - if len(node.inputs) == 2: - - @numba_basic.numba_njit - def binary_mul(x, y): - return x * y - - return binary_mul - signature = create_numba_signature(node, force_scalar=True) nary_add_fn = binary_to_nary_func(node.inputs, "mul", "*") diff --git a/pytensor/link/numba/dispatch/subtensor.py b/pytensor/link/numba/dispatch/subtensor.py index a5db84106e..ee9e183d16 100644 --- a/pytensor/link/numba/dispatch/subtensor.py +++ b/pytensor/link/numba/dispatch/subtensor.py @@ -13,7 +13,6 @@ AdvancedSubtensor1, IncSubtensor, Subtensor, - get_idx_list, ) from pytensor.tensor.type_other import NoneTypeT, SliceType @@ -96,9 +95,6 @@ def {function_name}({", ".join(input_names)}): return np.asarray(z) """ - print() - node.dprint(depth=2, print_type=True) - print("subtensor_def_src:", subtensor_def_src) func = compile_function_src( subtensor_def_src, function_name=function_name, @@ -107,25 +103,6 @@ def {function_name}({", ".join(input_names)}): return numba_njit(func, boundscheck=True) -@numba_funcify.register(Subtensor) -def numba_funcify_subtensor_custom(op, node, **kwargs): - idxs = get_idx_list(node.inputs, op.idx_list) - - if ( - idxs - and not isinstance(idxs[0], slice) - and all(idx == slice(None) for idx in idxs[1:]) - ): - - @numba_njit - def scalar_subtensor_leading_dim(x, idx): - return x[idx] - - return scalar_subtensor_leading_dim - - return numba_funcify_default_subtensor(op, node, **kwargs) - - @numba_funcify.register(AdvancedSubtensor) @numba_funcify.register(AdvancedIncSubtensor) def numba_funcify_AdvancedSubtensor(op, node, **kwargs): diff --git a/pytensor/link/numba/dispatch/vectorize_codegen.py b/pytensor/link/numba/dispatch/vectorize_codegen.py index 7052612e6e..74870e29bd 100644 --- a/pytensor/link/numba/dispatch/vectorize_codegen.py +++ b/pytensor/link/numba/dispatch/vectorize_codegen.py @@ -35,16 +35,6 @@ def store_core_outputs(i0, i1, ..., in, o0, o1, ..., on): on[...] = ton """ - if nin == 2 and nout == 1: - - @numba_basic.numba_njit - def store_core_outputs_2in1out(i0, i1, o0): - t0 = core_op_fn(i0, i1) - o0[...] 
= t0 - - return store_core_outputs_2in1out - print(nin, nout) - inputs = [f"i{i}" for i in range(nin)] outputs = [f"o{i}" for i in range(nout)] inner_outputs = [f"t{output}" for output in outputs] From 799edf03d11cbfc446814b5e54cf3fb5f1c3e889 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 28 Mar 2025 19:02:38 +0100 Subject: [PATCH 15/15] Cache stuff hazardly --- pytensor/link/numba/dispatch/elemwise.py | 8 +- .../link/numba/dispatch/vectorize_codegen.py | 20 ++- pytensor/link/numba/super_utils.py | 138 ++++++++++++++++++ 3 files changed, 160 insertions(+), 6 deletions(-) create mode 100644 pytensor/link/numba/super_utils.py diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 9fd81dadcf..15933a8928 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -264,6 +264,12 @@ def axis_apply_fn(x): @numba_funcify.register(Elemwise) def numba_funcify_Elemwise(op, node, **kwargs): + # op = getattr(np, str(op.scalar_op).lower()) + # @numba_njit + # def elemwise_is_numpy(x): + # return op(x) + # return elemwise_is_numpy + scalar_inputs = [get_scalar_type(dtype=input.dtype)() for input in node.inputs] scalar_node = op.scalar_op.make_node(*scalar_inputs) @@ -276,7 +282,7 @@ def numba_funcify_Elemwise(op, node, **kwargs): nin = len(node.inputs) nout = len(node.outputs) - core_op_fn = store_core_outputs(scalar_op_fn, nin=nin, nout=nout) + core_op_fn = store_core_outputs(scalar_op_fn, op.scalar_op, nin=nin, nout=nout) input_bc_patterns = tuple(inp.type.broadcastable for inp in node.inputs) output_bc_patterns = tuple(out.type.broadcastable for out in node.outputs) diff --git a/pytensor/link/numba/dispatch/vectorize_codegen.py b/pytensor/link/numba/dispatch/vectorize_codegen.py index 74870e29bd..7c077e5eb2 100644 --- a/pytensor/link/numba/dispatch/vectorize_codegen.py +++ b/pytensor/link/numba/dispatch/vectorize_codegen.py @@ -3,6 +3,7 @@ import base64 import pickle from collections.abc import Callable, Sequence +from hashlib import sha256 from textwrap import indent from typing import Any, cast @@ -15,15 +16,19 @@ from numba.core.types.misc import NoneType from numba.np import arrayobj +from pytensor.graph.op import HasInnerGraph from pytensor.link.numba.dispatch import basic as numba_basic -from pytensor.link.utils import compile_function_src +from pytensor.link.numba.super_utils import compile_function_src2 +from pytensor.scalar import ScalarOp def encode_literals(literals: Sequence) -> str: return base64.encodebytes(pickle.dumps(literals)).decode() -def store_core_outputs(core_op_fn: Callable, nin: int, nout: int) -> Callable: +def store_core_outputs( + core_op_fn: Callable, core_op: ScalarOp, nin: int, nout: int +) -> Callable: """Create a Numba function that wraps a core function and stores its vectorized outputs. 
@njit @@ -52,9 +57,14 @@ def store_core_outputs({inp_signature}, {out_signature}): {indent(store_outputs, " " * 4)} """ global_env = {"core_op_fn": core_op_fn} - func = compile_function_src( - func_src, "store_core_outputs", {**globals(), **global_env} - ) + # func = compile_function_src( + # func_src, "store_core_outputs", {**globals(), **global_env}, + # ) + if isinstance(core_op, HasInnerGraph): + key = sha256(core_op.c_code_template.encode()).hexdigest() + else: + key = str(core_op) + func = compile_function_src2(key, func_src, "store_core_outputs", global_env) return cast(Callable, numba_basic.numba_njit(func)) diff --git a/pytensor/link/numba/super_utils.py b/pytensor/link/numba/super_utils.py new file mode 100644 index 0000000000..12b33004f6 --- /dev/null +++ b/pytensor/link/numba/super_utils.py @@ -0,0 +1,138 @@ +import importlib +import os +import sys +import tempfile +from collections.abc import Callable +from typing import Any + +import numba +import numba.core.caching +from numba.core.caching import CacheImpl + + +class PyTensorLoader(importlib.abc.SourceLoader): + def __init__(self): + # Key is "pytensor_generated_" + hash of pytensor graph + self._module_sources = {} + self._module_globals = {} + self._module_locals = {} + + def get_source(self, fullname): + if fullname not in self._module_sources: + raise ImportError() + return self._module_sources[fullname] + + def get_data(self, path): + if path not in self._module_sources: + raise ImportError() + return self._module_sources[path].encode("utf-8") + + def get_filename(self, path): + if path not in self._module_sources: + raise ImportError() + return path + + def add_module(self, name, src, global_env, local_env): + self._module_sources[name] = src + self._module_globals[name] = global_env + self._module_locals[name] = local_env + + def exec_module(self, module): + name = module.__name__ + variables = module.__dict__ + variables.update(self._module_globals[name]) + variables.update(self._module_locals[name]) + code = compile(self._module_sources[name], name, "exec") + exec(code, variables) + + def create_module(self, spec): + return None + + +pytensor_loader = PyTensorLoader() + + +def load_module(key, src, global_env, local_env): + pytensor_loader.add_module(key, src, global_env, local_env) + spec = importlib.util.spec_from_loader(key, pytensor_loader) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + sys.modules[key] = module + return module + + +class NumbaPyTensorCacheLocator(numba.core.caching._CacheLocator): + def __init__(self, py_func, py_file): + # print(f"New locator {py_func=}, {py_file=}") + self._py_func = py_func + self._py_file = py_file + self._hash = py_file + # src_hash = hash(pytensor_loader._module_sources[self._py_file]) + # self._hash = hash((src_hash, py_file, pytensor.__version__)) + + def ensure_cache_path(self): + path = self.get_cache_path() + os.makedirs(path, exist_ok=True) + # Ensure the directory is writable by trying to write a temporary file + tempfile.TemporaryFile(dir=path).close() + + def get_cache_path(self): + """ + Return the directory the function is cached in. + """ + return "~/.cache/pytensor" + + def get_source_stamp(self): + """ + Get a timestamp representing the source code's freshness. + Can return any picklable Python object. + """ + + return self._hash + + def get_disambiguator(self): + """ + Get a string disambiguator for this locator's function. + It should allow disambiguating different but similarly-named functions. 
+ """ + return None + + @classmethod + def from_function(cls, py_func, py_file): + """ + Create a locator instance for the given function located in the + given file. + """ + if py_func.__module__ in pytensor_loader._module_sources: + return cls(py_func, py_file) + + +CacheImpl._locator_classes.append(NumbaPyTensorCacheLocator) + + +def compile_function_src2( + key: str, + src: str, + function_name: str, + global_env: dict[Any, Any] | None = None, + local_env: dict[Any, Any] | None = None, +) -> Callable: + # with NamedTemporaryFile(delete=False) as f: + # filename = f.name + # f.write(src.encode()) + + if global_env is None: + global_env = {} + + if local_env is None: + local_env = {} + + # mod_code = compile(src, filename, mode="exec") + # exec(mod_code, global_env, local_env) + # print(key, src) + module = load_module(key, src, global_env, local_env) + res = getattr(module, function_name) + + # res = cast(Callable, res) + # res.__source__ = src # type: ignore + return res