diff --git a/pytensor/link/numba/dispatch/tensor_basic.py b/pytensor/link/numba/dispatch/tensor_basic.py index 8f5972c058..e9bf683696 100644 --- a/pytensor/link/numba/dispatch/tensor_basic.py +++ b/pytensor/link/numba/dispatch/tensor_basic.py @@ -68,7 +68,7 @@ def numba_funcify_Alloc(op, node, **kwargs): shape_var_item_names = [f"{name}_item" for name in shape_var_names] shapes_to_items_src = indent( "\n".join( - f"{item_name} = to_scalar({shape_name})" + f"{item_name} = {shape_name}.item()" for item_name, shape_name in zip( shape_var_item_names, shape_var_names, strict=True ) @@ -86,12 +86,11 @@ def numba_funcify_Alloc(op, node, **kwargs): alloc_def_src = f""" def alloc(val, {", ".join(shape_var_names)}): - val_np = np.asarray(val) {shapes_to_items_src} scalar_shape = {create_tuple_string(shape_var_item_names)} {check_runtime_broadcast_src} - res = np.empty(scalar_shape, dtype=val_np.dtype) - res[...] = val_np + res = np.empty(scalar_shape, dtype=val.dtype) + res[...] = val return res """ alloc_fn = compile_function_src(alloc_def_src, "alloc", {**globals(), **global_env})