Skip to content

Commit

Permalink
Use disable numba JIT
Browse files Browse the repository at this point in the history
  • Loading branch information
Smit-create committed Mar 14, 2023
1 parent 984ee55 commit 402a38f
Showing 1 changed file with 9 additions and 69 deletions.
78 changes: 9 additions & 69 deletions tests/link/numba/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import contextlib
import inspect
from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, Tuple, Union
from unittest import mock

import numba
import numpy as np
Expand Down Expand Up @@ -108,73 +106,15 @@ def compare_shape_dtype(x, y):
def eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode):
"""Evaluate the Numba implementation in pure Python for coverage purposes."""

def py_tuple_setitem(t, i, v):
ll = list(t)
ll[i] = v
return tuple(ll)

def py_to_scalar(x):
if isinstance(x, np.ndarray):
return x.item()
else:
return x

def njit_noop(*args, **kwargs):
if len(args) == 1 and callable(args[0]):
return args[0]
else:
return lambda x: x

def vectorize_noop(*args, **kwargs):
def wrap(fn):
# `numba.vectorize` allows an `out` positional argument. We need
# to account for that
sig = inspect.signature(fn)
nparams = len(sig.parameters)

def inner_vec(*args):
if len(args) > nparams:
# An `out` argument has been specified for an in-place
# operation
out = args[-1]
out[...] = np.vectorize(fn)(*args[:nparams])
return out
else:
return np.vectorize(fn)(*args)

return inner_vec

if len(args) == 1 and callable(args[0]):
return wrap(args[0], **kwargs)
else:
return wrap

mocks = [
mock.patch("numba.njit", njit_noop),
mock.patch("numba.vectorize", vectorize_noop),
mock.patch("aesara.link.numba.dispatch.basic.tuple_setitem", py_tuple_setitem),
mock.patch("aesara.link.numba.dispatch.basic.numba_njit", njit_noop),
mock.patch("aesara.link.numba.dispatch.basic.numba_vectorize", vectorize_noop),
mock.patch("aesara.link.numba.dispatch.basic.direct_cast", lambda x, dtype: x),
mock.patch("aesara.link.numba.dispatch.basic.to_scalar", py_to_scalar),
mock.patch(
"aesara.link.numba.dispatch.basic.numba.np.numpy_support.from_dtype",
lambda dtype: dtype,
),
mock.patch("numba.np.unsafe.ndarray.to_fixed_tuple", lambda x, n: tuple(x)),
]

with contextlib.ExitStack() as stack:
for ctx in mocks:
stack.enter_context(ctx)

aesara_numba_fn = function(
fn_inputs,
fn_outputs,
mode=mode,
accept_inplace=True,
)
_ = aesara_numba_fn(*inputs)
numba.config.DISABLE_JIT = True
aesara_numba_fn = function(
fn_inputs,
fn_outputs,
mode=mode,
accept_inplace=True,
)
_ = aesara_numba_fn(*inputs)
numba.config.DISABLE_JIT = False


def compare_numba_and_py(
Expand Down

0 comments on commit 402a38f

Please sign in to comment.