From 1d306f48cf2476776cb55f535e76a7c67c8cc84c Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 7 Jun 2024 19:04:51 +0200 Subject: [PATCH] Simplify Numba implementation of Alloc --- pytensor/link/numba/dispatch/tensor_basic.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pytensor/link/numba/dispatch/tensor_basic.py b/pytensor/link/numba/dispatch/tensor_basic.py index 8f5972c058..e9bf683696 100644 --- a/pytensor/link/numba/dispatch/tensor_basic.py +++ b/pytensor/link/numba/dispatch/tensor_basic.py @@ -68,7 +68,7 @@ def numba_funcify_Alloc(op, node, **kwargs): shape_var_item_names = [f"{name}_item" for name in shape_var_names] shapes_to_items_src = indent( "\n".join( - f"{item_name} = to_scalar({shape_name})" + f"{item_name} = {shape_name}.item()" for item_name, shape_name in zip( shape_var_item_names, shape_var_names, strict=True ) @@ -86,12 +86,11 @@ def numba_funcify_Alloc(op, node, **kwargs): alloc_def_src = f""" def alloc(val, {", ".join(shape_var_names)}): - val_np = np.asarray(val) {shapes_to_items_src} scalar_shape = {create_tuple_string(shape_var_item_names)} {check_runtime_broadcast_src} - res = np.empty(scalar_shape, dtype=val_np.dtype) - res[...] = val_np + res = np.empty(scalar_shape, dtype=val.dtype) + res[...] = val return res """ alloc_fn = compile_function_src(alloc_def_src, "alloc", {**globals(), **global_env})