Skip to content

Commit 4b2d89c

Browse files
committed
Add in protections against call to eval(expression)
1 parent 74d5973 commit 4b2d89c

File tree

2 files changed

+61
-15
lines changed

2 files changed

+61
-15
lines changed

numexpr/necompiler.py

+16-10
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import sys
1414
import numpy
1515
import threading
16+
import re
1617

1718
is_cpu_amd_intel = False # DEPRECATION WARNING: WILL BE REMOVED IN FUTURE RELEASE
1819
from numexpr import interpreter, expressions, use_vml
@@ -259,10 +260,17 @@ def __init__(self, astnode):
259260
def __str__(self):
260261
return 'Immediate(%d)' % (self.node.value,)
261262

262-
263+
# Tokens rejected before an expression string is handed to `eval`:
# `;`, `[`, `:` and the double underscore `__`.
# A raw string is used so the pattern contains no invalid string escapes
# (`\;` and `\:` are not recognised escape sequences), and `[` is escaped
# explicitly inside the character class to keep the set unambiguous.
_forbidden_re = re.compile(r'[;\[:]|__')
263264
def stringToExpression(s, types, context):
264265
"""Given a string, convert it to a tree of ExpressionNode's.
265266
"""
267+
# sanitize the string for obvious attack vectors that NumExpr cannot
268+
# parse into its homebrew AST. This is to protect the call to `eval` below.
269+
# We forbid `;`, `:`, `[` and `__`.
270+
# We would like to forbid `.` but it is both a reference and decimal point.
271+
if _forbidden_re.search(s) is not None:
272+
raise ValueError(f'Expression {s} has forbidden control characters.')
273+
266274
old_ctx = expressions._context.get_current_context()
267275
try:
268276
expressions._context.set_new_context(context)
@@ -285,8 +293,10 @@ def stringToExpression(s, types, context):
285293
t = types.get(name, default_type)
286294
names[name] = expressions.VariableNode(name, type_to_kind[t])
287295
names.update(expressions.functions)
296+
288297
# now build the expression
289298
ex = eval(c, names)
299+
290300
if expressions.isConstant(ex):
291301
ex = expressions.ConstantNode(ex, expressions.getKind(ex))
292302
elif not isinstance(ex, expressions.ExpressionNode):
@@ -611,9 +621,7 @@ def NumExpr(ex, signature=(), **kwargs):
611621
612622
Returns a `NumExpr` object containing the compiled function.
613623
"""
614-
# NumExpr can be called either directly by the end-user, in which case
615-
# kwargs need to be sanitized by getContext, or by evaluate,
616-
# in which case kwargs are in already sanitized.
624+
617625
# In that case _frame_depth is wrong (it should be 2) but it doesn't matter
618626
# since it will not be used (because truediv='auto' has already been
619627
# translated to either True or False).
@@ -758,7 +766,7 @@ def getArguments(names, local_dict=None, global_dict=None, _frame_depth: int=2):
758766
_names_cache = CacheDict(256)
759767
_numexpr_cache = CacheDict(256)
760768
_numexpr_last = {}
761-
769+
_numexpr_sanity = set()
762770
evaluate_lock = threading.Lock()
763771

764772
# MAYBE: decorate this function to add attributes instead of having the
@@ -861,7 +869,7 @@ def evaluate(ex: str,
861869
out: numpy.ndarray = None,
862870
order: str = 'K',
863871
casting: str = 'safe',
864-
_frame_depth: int=3,
872+
_frame_depth: int = 3,
865873
**kwargs) -> numpy.ndarray:
866874
"""
867875
Evaluate a simple array expression element-wise using the virtual machine.
@@ -909,6 +917,8 @@ def evaluate(ex: str,
909917
_frame_depth: int
910918
The calling frame depth. Unless you are a NumExpr developer you should
911919
not set this value.
920+
921+
912922
"""
913923
# We could avoid code duplication if we called validate and then re_evaluate
914924
# here, but then we have difficulties with the `sys.getframe(2)` call in
@@ -921,10 +931,6 @@ def evaluate(ex: str,
921931
else:
922932
raise e
923933

924-
925-
926-
927-
928934
def re_evaluate(local_dict: Optional[Dict] = None,
929935
_frame_depth: int=2) -> numpy.ndarray:
930936
"""

numexpr/tests/test_numexpr.py

+45-5
Original file line numberDiff line numberDiff line change
@@ -373,8 +373,9 @@ def test_re_evaluate_dict(self):
373373
a1 = array([1., 2., 3.])
374374
b1 = array([4., 5., 6.])
375375
c1 = array([7., 8., 9.])
376-
x = evaluate("2*a + 3*b*c", local_dict={'a': a1, 'b': b1, 'c': c1})
377-
x = re_evaluate()
376+
local_dict={'a': a1, 'b': b1, 'c': c1}
377+
x = evaluate("2*a + 3*b*c", local_dict=local_dict)
378+
x = re_evaluate(local_dict=local_dict)
378379
assert_array_equal(x, array([86., 124., 168.]))
379380

380381
def test_validate(self):
@@ -400,9 +401,10 @@ def test_validate_dict(self):
400401
a1 = array([1., 2., 3.])
401402
b1 = array([4., 5., 6.])
402403
c1 = array([7., 8., 9.])
403-
retval = validate("2*a + 3*b*c", local_dict={'a': a1, 'b': b1, 'c': c1})
404+
local_dict={'a': a1, 'b': b1, 'c': c1}
405+
retval = validate("2*a + 3*b*c", local_dict=local_dict)
404406
assert(retval is None)
405-
x = re_evaluate()
407+
x = re_evaluate(local_dict=local_dict)
406408
assert_array_equal(x, array([86., 124., 168.]))
407409

408410
# Test for issue #22
@@ -502,11 +504,49 @@ def test_illegal_value(self):
502504
a = arange(3)
503505
try:
504506
evaluate("a < [0, 0, 0]")
505-
except TypeError:
507+
except (ValueError, TypeError):
508+
pass
509+
else:
510+
self.fail()
511+
512+
def test_forbidden_tokens(self):
513+
# Forbid dunder
514+
try:
515+
evaluate('__builtins__')
516+
except ValueError:
517+
pass
518+
else:
519+
self.fail()
520+
521+
# Forbid colon for lambda funcs
522+
try:
523+
evaluate('lambda x: x')
524+
except ValueError:
525+
pass
526+
else:
527+
self.fail()
528+
529+
# Forbid indexing
530+
try:
531+
evaluate('locals()[]')
532+
except ValueError:
506533
pass
507534
else:
508535
self.fail()
509536

537+
# Forbid semicolon
538+
try:
539+
evaluate('import os; os.cpu_count()')
540+
except ValueError:
541+
pass
542+
else:
543+
self.fail()
544+
545+
# I struggle to come up with cases for our ban on `'` and `"`
546+
547+
548+
549+
510550
def test_disassemble(self):
511551
assert_equal(disassemble(NumExpr(
512552
"where(m, a, -1)", [('m', bool), ('a', float)])),

0 commit comments

Comments
 (0)