Skip to content

Commit 74d5973

Browse files
committed
Adding tests for validate and noticed that re_evaluate tests using local_dict argument are flawed and do not actually work
1 parent 0032150 commit 74d5973

File tree

2 files changed

+32
-3
lines changed

2 files changed

+32
-3
lines changed

numexpr/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@
3131
import os, os.path
3232
import platform
3333
from numexpr.expressions import E
34-
from numexpr.necompiler import NumExpr, disassemble, evaluate, re_evaluate
34+
from numexpr.necompiler import (NumExpr, disassemble, evaluate, re_evaluate,
35+
validate)
3536

3637
from numexpr.utils import (_init_num_threads,
3738
get_vml_version, set_vml_accuracy_mode, set_vml_num_threads,

numexpr/tests/test_numexpr.py

+30-2
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from numpy import shape, allclose, array_equal, ravel, isnan, isinf
3232

3333
import numexpr
34-
from numexpr import E, NumExpr, evaluate, re_evaluate, disassemble, use_vml
34+
from numexpr import E, NumExpr, evaluate, re_evaluate, validate, disassemble, use_vml
3535
from numexpr.expressions import ConstantNode
3636

3737
import unittest
@@ -370,10 +370,38 @@ def test_re_evaluate(self):
370370
assert_array_equal(x, array([86., 124., 168.]))
371371

372372
def test_re_evaluate_dict(self):
373+
a1 = array([1., 2., 3.])
374+
b1 = array([4., 5., 6.])
375+
c1 = array([7., 8., 9.])
376+
x = evaluate("2*a + 3*b*c", local_dict={'a': a1, 'b': b1, 'c': c1})
377+
x = re_evaluate()
378+
assert_array_equal(x, array([86., 124., 168.]))
379+
380+
def test_validate(self):
373381
a = array([1., 2., 3.])
374382
b = array([4., 5., 6.])
375383
c = array([7., 8., 9.])
376-
x = evaluate("2*a + 3*b*c", local_dict={'a': a, 'b': b, 'c': c})
384+
retval = validate("2*a + 3*b*c")
385+
assert(retval is None)
386+
x = re_evaluate()
387+
assert_array_equal(x, array([86., 124., 168.]))
388+
389+
def test_validate_missing_var(self):
390+
a = array([1., 2., 3.])
391+
b = array([4., 5., 6.])
392+
retval = validate("2*a + 3*b*c")
393+
assert(isinstance(retval, KeyError))
394+
395+
def test_validate_syntax(self):
396+
retval = validate("2+")
397+
assert(isinstance(retval, SyntaxError))
398+
399+
def test_validate_dict(self):
400+
a1 = array([1., 2., 3.])
401+
b1 = array([4., 5., 6.])
402+
c1 = array([7., 8., 9.])
403+
retval = validate("2*a + 3*b*c", local_dict={'a': a1, 'b': b1, 'c': c1})
404+
assert(retval is None)
377405
x = re_evaluate()
378406
assert_array_equal(x, array([86., 124., 168.]))
379407

0 commit comments

Comments
 (0)