diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9677615206..48ada81faf 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -193,7 +193,7 @@ jobs: else micromamba install --yes -q "python~=${PYTHON_VERSION}" mkl "numpy${NUMPY_VERSION}" scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock; fi - if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi + micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro && pip install tensorflow-probability; fi if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi pip install pytest-sphinx diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py index ffa27e5d5a..adf9732f5e 100644 --- a/pytensor/compile/mode.py +++ b/pytensor/compile/mode.py @@ -63,9 +63,8 @@ def register_linker(name, linker): # If a string is passed as the optimizer argument in the constructor # for Mode, it will be used as the key to retrieve the real optimizer # in this dictionary -exclude = [] -if not config.cxx: - exclude = ["cxx_only"] + +exclude = ["cxx_only", "BlasOpt"] OPT_NONE = RewriteDatabaseQuery(include=[], exclude=exclude) # Even if multiple merge optimizer call will be there, this shouldn't # impact performance. @@ -346,6 +345,11 @@ def __setstate__(self, state): optimizer = predefined_optimizers[optimizer] if isinstance(optimizer, RewriteDatabaseQuery): self.provided_optimizer = optimizer + + # Force numba-required rewrites if using NumbaLinker + if isinstance(linker, NumbaLinker): + optimizer = optimizer.including("numba") + self._optimizer = optimizer self.call_time = 0 self.fn_time = 0 @@ -443,16 +447,20 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): # string as the key # Use VM_linker to allow lazy evaluation by default. FAST_COMPILE = Mode( - VMLinker(use_cloop=False, c_thunks=False), - RewriteDatabaseQuery(include=["fast_compile", "py_only"]), + NumbaLinker(), + # TODO: Fast_compile should just use python code, CHANGE ME! + RewriteDatabaseQuery( + include=["fast_compile", "numba"], + exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"], + ), +) +FAST_RUN = Mode( + NumbaLinker(), + RewriteDatabaseQuery( + include=["fast_run", "numba"], + exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"], + ), ) -if config.cxx: - FAST_RUN = Mode("cvm", "fast_run") -else: - FAST_RUN = Mode( - "vm", - RewriteDatabaseQuery(include=["fast_run", "py_only"]), - ) NUMBA = Mode( NumbaLinker(), @@ -565,6 +573,7 @@ def register_mode(name, mode): Add a `Mode` which can be referred to by `name` in `function`. """ + # TODO: Remove me if name in predefined_modes: raise ValueError(f"Mode name already taken: {name}") predefined_modes[name] = mode diff --git a/pytensor/configdefaults.py b/pytensor/configdefaults.py index ca3c44bf6d..9c8fecab33 100644 --- a/pytensor/configdefaults.py +++ b/pytensor/configdefaults.py @@ -370,11 +370,21 @@ def add_compile_configvars(): if rc == 0 and config.cxx != "": # Keep the default linker the same as the one for the mode FAST_RUN - linker_options = ["c|py", "py", "c", "c|py_nogc", "vm", "vm_nogc", "cvm_nogc"] + linker_options = [ + "cvm", + "c|py", + "py", + "c", + "c|py_nogc", + "vm", + "vm_nogc", + "cvm_nogc", + "jax", + ] else: # g++ is not present or the user disabled it, # linker should default to python only. - linker_options = ["py", "vm_nogc"] + linker_options = ["py", "vm", "vm_nogc", "jax"] if type(config).cxx.is_default: # If the user provided an empty value for cxx, do not warn. _logger.warning( @@ -388,7 +398,7 @@ def add_compile_configvars(): "linker", "Default linker used if the pytensor flags mode is Mode", # Not mutable because the default mode is cached after the first use. - EnumStr("cvm", linker_options, mutable=False), + EnumStr("numba", linker_options, mutable=False), in_c_key=False, ) diff --git a/pytensor/link/jax/linker.py b/pytensor/link/jax/linker.py index 06370b4514..5e1a87e31e 100644 --- a/pytensor/link/jax/linker.py +++ b/pytensor/link/jax/linker.py @@ -87,3 +87,6 @@ def create_thunk_inputs(self, storage_map): thunk_inputs.append(sinput) return thunk_inputs + + def __repr__(self): + return "JAXLinker()" diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 2f3cac6ea6..f961524fa0 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -40,6 +40,7 @@ from pytensor.tensor.slinalg import Solve from pytensor.tensor.type import TensorType from pytensor.tensor.type_other import MakeSlice, NoneConst +from pytensor.typed_list import TypedListType def global_numba_func(func): @@ -135,6 +136,8 @@ def get_numba_type( return CSCMatrixType(numba_dtype) raise NotImplementedError() + elif isinstance(pytensor_type, TypedListType): + return numba.types.List(get_numba_type(pytensor_type.ttype)) else: raise NotImplementedError(f"Numba type not implemented for {pytensor_type}") @@ -481,11 +484,11 @@ def numba_funcify_SpecifyShape(op, node, **kwargs): shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))] func_conditions = [ - f"assert x.shape[{i}] == {shape_input_names}" - for i, (shape_input, shape_input_names) in enumerate( + f"assert x.shape[{i}] == {eval_dim_name}, f'SpecifyShape: dim {{{i}}} of input has shape {{x.shape[{i}]}}, expected {{{eval_dim_name}.item()}}.'" + for i, (node_dim_input, eval_dim_name) in enumerate( zip(shape_inputs, shape_input_names, strict=True) ) - if shape_input is not NoneConst + if node_dim_input is not NoneConst ] func = dedent( diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 9fd81dadcf..15933a8928 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -264,6 +264,12 @@ def axis_apply_fn(x): @numba_funcify.register(Elemwise) def numba_funcify_Elemwise(op, node, **kwargs): + # op = getattr(np, str(op.scalar_op).lower()) + # @numba_njit + # def elemwise_is_numpy(x): + # return op(x) + # return elemwise_is_numpy + scalar_inputs = [get_scalar_type(dtype=input.dtype)() for input in node.inputs] scalar_node = op.scalar_op.make_node(*scalar_inputs) @@ -276,7 +282,7 @@ def numba_funcify_Elemwise(op, node, **kwargs): nin = len(node.inputs) nout = len(node.outputs) - core_op_fn = store_core_outputs(scalar_op_fn, nin=nin, nout=nout) + core_op_fn = store_core_outputs(scalar_op_fn, op.scalar_op, nin=nin, nout=nout) input_bc_patterns = tuple(inp.type.broadcastable for inp in node.inputs) output_bc_patterns = tuple(out.type.broadcastable for out in node.outputs) diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py index e9b637b00f..31904eac04 100644 --- a/pytensor/link/numba/dispatch/scalar.py +++ b/pytensor/link/numba/dispatch/scalar.py @@ -23,6 +23,7 @@ Composite, Identity, Mul, + Pow, Reciprocal, ScalarOp, Second, @@ -154,6 +155,21 @@ def numba_funcify_Switch(op, node, **kwargs): return numba_basic.global_numba_func(switch) +@numba_funcify.register(Pow) +def numba_funcify_Pow(op, node, **kwargs): + pow_dtype = node.inputs[1].type.dtype + + def pow(x, y): + return x**y + + # Work-around https://github.com/numba/numba/issues/9554 + # fast-math casuse kernel crash + patch_kwargs = {} + if pow_dtype.startswith("int"): + patch_kwargs["fastmath"] = False + return numba_basic.numba_njit(**patch_kwargs)(pow) + + def binary_to_nary_func(inputs: list[Variable], binary_op_name: str, binary_op: str): """Create a Numba-compatible N-ary function from a binary function.""" unique_names = unique_name_generator(["binary_op_name"], suffix_sep="_") diff --git a/pytensor/link/numba/dispatch/tensor_basic.py b/pytensor/link/numba/dispatch/tensor_basic.py index 8f5972c058..30e89c4256 100644 --- a/pytensor/link/numba/dispatch/tensor_basic.py +++ b/pytensor/link/numba/dispatch/tensor_basic.py @@ -68,7 +68,7 @@ def numba_funcify_Alloc(op, node, **kwargs): shape_var_item_names = [f"{name}_item" for name in shape_var_names] shapes_to_items_src = indent( "\n".join( - f"{item_name} = to_scalar({shape_name})" + f"{item_name} = {shape_name}.item()" for item_name, shape_name in zip( shape_var_item_names, shape_var_names, strict=True ) @@ -86,7 +86,7 @@ def numba_funcify_Alloc(op, node, **kwargs): alloc_def_src = f""" def alloc(val, {", ".join(shape_var_names)}): - val_np = np.asarray(val) + val_np = val {shapes_to_items_src} scalar_shape = {create_tuple_string(shape_var_item_names)} {check_runtime_broadcast_src} diff --git a/pytensor/link/numba/dispatch/vectorize_codegen.py b/pytensor/link/numba/dispatch/vectorize_codegen.py index 74870e29bd..7c077e5eb2 100644 --- a/pytensor/link/numba/dispatch/vectorize_codegen.py +++ b/pytensor/link/numba/dispatch/vectorize_codegen.py @@ -3,6 +3,7 @@ import base64 import pickle from collections.abc import Callable, Sequence +from hashlib import sha256 from textwrap import indent from typing import Any, cast @@ -15,15 +16,19 @@ from numba.core.types.misc import NoneType from numba.np import arrayobj +from pytensor.graph.op import HasInnerGraph from pytensor.link.numba.dispatch import basic as numba_basic -from pytensor.link.utils import compile_function_src +from pytensor.link.numba.super_utils import compile_function_src2 +from pytensor.scalar import ScalarOp def encode_literals(literals: Sequence) -> str: return base64.encodebytes(pickle.dumps(literals)).decode() -def store_core_outputs(core_op_fn: Callable, nin: int, nout: int) -> Callable: +def store_core_outputs( + core_op_fn: Callable, core_op: ScalarOp, nin: int, nout: int +) -> Callable: """Create a Numba function that wraps a core function and stores its vectorized outputs. @njit @@ -52,9 +57,14 @@ def store_core_outputs({inp_signature}, {out_signature}): {indent(store_outputs, " " * 4)} """ global_env = {"core_op_fn": core_op_fn} - func = compile_function_src( - func_src, "store_core_outputs", {**globals(), **global_env} - ) + # func = compile_function_src( + # func_src, "store_core_outputs", {**globals(), **global_env}, + # ) + if isinstance(core_op, HasInnerGraph): + key = sha256(core_op.c_code_template.encode()).hexdigest() + else: + key = str(core_op) + func = compile_function_src2(key, func_src, "store_core_outputs", global_env) return cast(Callable, numba_basic.numba_njit(func)) diff --git a/pytensor/link/numba/linker.py b/pytensor/link/numba/linker.py index 553c5ef217..3cf542a316 100644 --- a/pytensor/link/numba/linker.py +++ b/pytensor/link/numba/linker.py @@ -35,3 +35,6 @@ def create_thunk_inputs(self, storage_map): thunk_inputs.append(sinput) return thunk_inputs + + def __repr__(self): + return "NumbaLinker()" diff --git a/pytensor/link/numba/super_utils.py b/pytensor/link/numba/super_utils.py new file mode 100644 index 0000000000..12b33004f6 --- /dev/null +++ b/pytensor/link/numba/super_utils.py @@ -0,0 +1,138 @@ +import importlib +import os +import sys +import tempfile +from collections.abc import Callable +from typing import Any + +import numba +import numba.core.caching +from numba.core.caching import CacheImpl + + +class PyTensorLoader(importlib.abc.SourceLoader): + def __init__(self): + # Key is "pytensor_generated_" + hash of pytensor graph + self._module_sources = {} + self._module_globals = {} + self._module_locals = {} + + def get_source(self, fullname): + if fullname not in self._module_sources: + raise ImportError() + return self._module_sources[fullname] + + def get_data(self, path): + if path not in self._module_sources: + raise ImportError() + return self._module_sources[path].encode("utf-8") + + def get_filename(self, path): + if path not in self._module_sources: + raise ImportError() + return path + + def add_module(self, name, src, global_env, local_env): + self._module_sources[name] = src + self._module_globals[name] = global_env + self._module_locals[name] = local_env + + def exec_module(self, module): + name = module.__name__ + variables = module.__dict__ + variables.update(self._module_globals[name]) + variables.update(self._module_locals[name]) + code = compile(self._module_sources[name], name, "exec") + exec(code, variables) + + def create_module(self, spec): + return None + + +pytensor_loader = PyTensorLoader() + + +def load_module(key, src, global_env, local_env): + pytensor_loader.add_module(key, src, global_env, local_env) + spec = importlib.util.spec_from_loader(key, pytensor_loader) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + sys.modules[key] = module + return module + + +class NumbaPyTensorCacheLocator(numba.core.caching._CacheLocator): + def __init__(self, py_func, py_file): + # print(f"New locator {py_func=}, {py_file=}") + self._py_func = py_func + self._py_file = py_file + self._hash = py_file + # src_hash = hash(pytensor_loader._module_sources[self._py_file]) + # self._hash = hash((src_hash, py_file, pytensor.__version__)) + + def ensure_cache_path(self): + path = self.get_cache_path() + os.makedirs(path, exist_ok=True) + # Ensure the directory is writable by trying to write a temporary file + tempfile.TemporaryFile(dir=path).close() + + def get_cache_path(self): + """ + Return the directory the function is cached in. + """ + return "~/.cache/pytensor" + + def get_source_stamp(self): + """ + Get a timestamp representing the source code's freshness. + Can return any picklable Python object. + """ + + return self._hash + + def get_disambiguator(self): + """ + Get a string disambiguator for this locator's function. + It should allow disambiguating different but similarly-named functions. + """ + return None + + @classmethod + def from_function(cls, py_func, py_file): + """ + Create a locator instance for the given function located in the + given file. + """ + if py_func.__module__ in pytensor_loader._module_sources: + return cls(py_func, py_file) + + +CacheImpl._locator_classes.append(NumbaPyTensorCacheLocator) + + +def compile_function_src2( + key: str, + src: str, + function_name: str, + global_env: dict[Any, Any] | None = None, + local_env: dict[Any, Any] | None = None, +) -> Callable: + # with NamedTemporaryFile(delete=False) as f: + # filename = f.name + # f.write(src.encode()) + + if global_env is None: + global_env = {} + + if local_env is None: + local_env = {} + + # mod_code = compile(src, filename, mode="exec") + # exec(mod_code, global_env, local_env) + # print(key, src) + module = load_module(key, src, global_env, local_env) + res = getattr(module, function_name) + + # res = cast(Callable, res) + # res.__source__ = src # type: ignore + return res diff --git a/tests/tensor/conv/test_abstract_conv.py b/tests/tensor/conv/test_abstract_conv.py index 23ba23e1e9..814f1eb80b 100644 --- a/tests/tensor/conv/test_abstract_conv.py +++ b/tests/tensor/conv/test_abstract_conv.py @@ -948,9 +948,9 @@ def run_gradinput( ) -@pytest.mark.skipif( - config.cxx == "", - reason="SciPy and cxx needed", +@pytest.mark.skipif(config.cxx == "", reason="cxx needed") +@pytest.mark.xfail( + reason="Involves Ops with no Python implementation for numba to use as fallback" ) class TestAbstractConvNoOptim(BaseTestConv2d): @classmethod @@ -1884,9 +1884,9 @@ def test_conv2d_grad_wrt_weights(self): ) -@pytest.mark.skipif( - config.cxx == "", - reason="SciPy and cxx needed", +@pytest.mark.skipif(config.cxx == "", reason="cxx needed") +@pytest.mark.xfail( + reason="Involves Ops with no Python implementation for numba to use as fallback" ) class TestGroupedConvNoOptim: conv = abstract_conv.AbstractConv2d @@ -2096,9 +2096,9 @@ def conv_gradinputs(filters_val, output_val): utt.verify_grad(conv_gradinputs, [kern, top], mode=self.mode, eps=1) -@pytest.mark.skipif( - config.cxx == "", - reason="SciPy and cxx needed", +@pytest.mark.skipif(config.cxx == "", reason="cxx needed") +@pytest.mark.xfail( + reason="Involves Ops with no Python implementation for numba to use as fallback" ) class TestGroupedConv3dNoOptim(TestGroupedConvNoOptim): conv = abstract_conv.AbstractConv3d diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py index 1730ae46ac..93635aeb4a 100644 --- a/tests/tensor/rewriting/test_basic.py +++ b/tests/tensor/rewriting/test_basic.py @@ -142,55 +142,68 @@ def rewrite(g, level="fast_run"): return g -def test_local_useless_slice(): - # test a simple matrix - x = matrix("x") - mode_excluding = get_default_mode().excluding( - "local_useless_slice", "local_mul_canonizer" - ) - mode_including = ( - get_default_mode() - .including("local_useless_slice") - .excluding("local_mul_canonizer") - ) +class TestME: + def local_useless_slice_tester(self): + # test a simple matrix + x = matrix("x") + mode_excluding = get_mode("NUMBA").excluding( + "local_useless_slice", "local_mul_canonizer" + ) + mode_including = ( + get_mode("NUMBA") + .including("local_useless_slice") + .excluding("local_mul_canonizer") + ) - # test with and without the useless slice - o = 2 * x[0, :] - f_excluding = function([x], o, mode=mode_excluding) - f_including = function([x], o, mode=mode_including) - rng = np.random.default_rng(utt.fetch_seed()) - test_inp = rng.integers(-10, 10, (4, 4)).astype("float32") - assert all(f_including(test_inp) == f_excluding(test_inp)) - # test to see if the slice is truly gone - apply_node = f_including.maker.fgraph.toposort()[0] - subtens = apply_node.op - assert not any(isinstance(idx, slice) for idx in subtens.idx_list) - - # Now test that the stack trace is copied over properly, - # before before and after rewriting. - assert check_stack_trace(f_excluding, ops_to_check="all") - assert check_stack_trace(f_including, ops_to_check="all") - - # test a 4d tensor - z = tensor4("z") - o2 = z[1, :, :, 1] - o3 = z[0, :, :, :] - f_including_check = function([z], o2, mode=mode_including) - f_including_check_apply = function([z], o3, mode=mode_including) - - # The rewrite shouldn't apply here - apply_node = f_including_check.maker.fgraph.toposort()[0] - subtens = apply_node.op - assert [isinstance(idx, slice) for idx in subtens.idx_list].count(True) == 2 - # But it should here - apply_node = f_including_check_apply.maker.fgraph.toposort()[0] - subtens = apply_node.op - assert not any(isinstance(idx, slice) for idx in subtens.idx_list) - - # Finally, test that the stack trace is copied over properly, - # before before and after rewriting. - assert check_stack_trace(f_including_check, ops_to_check=Subtensor) - assert check_stack_trace(f_including_check_apply, ops_to_check=Subtensor) + # test with and without the useless slice + o = 2 * x[0, :] + f_excluding = function([x], o, mode=mode_excluding) + f_including = function([x], o, mode=mode_including) + rng = np.random.default_rng(utt.fetch_seed()) + test_inp = rng.integers(-10, 10, (4, 4)).astype("float32") + assert all(f_including(test_inp) == f_excluding(test_inp)) + # test to see if the slice is truly gone + apply_node = f_including.maker.fgraph.toposort()[0] + subtens = apply_node.op + assert not any(isinstance(idx, slice) for idx in subtens.idx_list) + + # Now test that the stack trace is copied over properly, + # before before and after rewriting. + assert check_stack_trace(f_excluding, ops_to_check="all") + assert check_stack_trace(f_including, ops_to_check="all") + + # test a 4d tensor + z = tensor4("z") + o2 = z[1, :, :, 1] + o3 = z[0, :, :, :] + f_including_check = function([z], o2, mode=mode_including) + f_including_check_apply = function([z], o3, mode=mode_including) + + # The rewrite shouldn't apply here + apply_node = f_including_check.maker.fgraph.toposort()[0] + subtens = apply_node.op + assert [isinstance(idx, slice) for idx in subtens.idx_list].count(True) == 2 + # But it should here + apply_node = f_including_check_apply.maker.fgraph.toposort()[0] + subtens = apply_node.op + assert not any(isinstance(idx, slice) for idx in subtens.idx_list) + + # Finally, test that the stack trace is copied over properly, + # before before and after rewriting. + assert check_stack_trace(f_including_check, ops_to_check=Subtensor) + assert check_stack_trace(f_including_check_apply, ops_to_check=Subtensor) + + def test_t0(self): + import pytensor + + x = pt.vector("x") + pytensor.function([x], x + 1) + + def test_t1(self): + self.local_useless_slice_tester() + + def test_t2(self): + self.local_useless_slice_tester() def test_local_useless_fill(): diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index dee0023efd..dfe9f2630f 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -150,7 +150,10 @@ ) -pytestmark = pytest.mark.filterwarnings("error") +pytestmark = pytest.mark.filterwarnings( + "error", + "ignore:Numba will use object mode:UserWarning", +) if config.mode == "FAST_COMPILE": mode_opt = "FAST_RUN" @@ -758,41 +761,43 @@ def check_allocs_in_fgraph(fgraph, n): def setup_method(self): self.rng = np.random.default_rng(seed=utt.fetch_seed()) - def test_alloc_constant_folding(self): + @pytest.mark.parametrize( + "subtensor_fn, expected_grad_n_alloc", + [ + # IncSubtensor1 + (lambda x: x[:60], 1), + # AdvancedIncSubtensor1 + (lambda x: x[np.arange(60)], 1), + # AdvancedIncSubtensor + (lambda x: x[np.arange(50), np.arange(50)], 1), + ], + ) + def test_alloc_constant_folding(self, subtensor_fn, expected_grad_n_alloc): test_params = np.asarray(self.rng.standard_normal(50 * 60), self.dtype) some_vector = vector("some_vector", dtype=self.dtype) some_matrix = some_vector.reshape((60, 50)) variables = self.shared(np.ones((50,), dtype=self.dtype)) - idx = constant(np.arange(50)) - for alloc_, (subtensor, n_alloc) in zip( - self.allocs, - [ - # IncSubtensor1 - (some_matrix[:60], 2), - # AdvancedIncSubtensor1 - (some_matrix[arange(60)], 2), - # AdvancedIncSubtensor - (some_matrix[idx, idx], 1), - ], - strict=True, - ): - derp = pt_sum(dense_dot(subtensor, variables)) - - fobj = pytensor.function([some_vector], derp, mode=self.mode) - grad_derp = pytensor.grad(derp, some_vector) - fgrad = pytensor.function([some_vector], grad_derp, mode=self.mode) + subtensor = subtensor_fn(some_matrix) - topo_obj = fobj.maker.fgraph.toposort() - assert sum(isinstance(node.op, type(alloc_)) for node in topo_obj) == 0 + derp = pt_sum(dense_dot(subtensor, variables)) + fobj = pytensor.function([some_vector], derp, mode=self.mode) + assert ( + sum(isinstance(node.op, Alloc) for node in fobj.maker.fgraph.apply_nodes) + == 0 + ) + # TODO: Assert something about the value if we bothered to call it? + fobj(test_params) - topo_grad = fgrad.maker.fgraph.toposort() - assert ( - sum(isinstance(node.op, type(alloc_)) for node in topo_grad) == n_alloc - ), (alloc_, subtensor, n_alloc, topo_grad) - fobj(test_params) - fgrad(test_params) + grad_derp = pytensor.grad(derp, some_vector) + fgrad = pytensor.function([some_vector], grad_derp, mode=self.mode) + assert ( + sum(isinstance(node.op, Alloc) for node in fgrad.maker.fgraph.apply_nodes) + == expected_grad_n_alloc + ) + # TODO: Assert something about the value if we bothered to call it? + fgrad(test_params) def test_alloc_output(self): val = constant(self.rng.standard_normal((1, 1)), dtype=self.dtype)