Skip to content

Commit 5143018

Browse files
brendan-m-murphyricardoV94
authored andcommitted
Fix test for neg on unsigned
Due to changes in numpy conversion rules (NEP 50), overflows are not ignored; in particular, negating a unsigned int causes an overflow error. The test for `neg` has been changed to check that this error is raised.
1 parent 4c8c8b6 commit 5143018

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

tests/tensor/test_math.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from pytensor.graph.fg import FunctionGraph
2424
from pytensor.graph.replace import vectorize_node
2525
from pytensor.link.c.basic import DualLinker
26+
from pytensor.npy_2_compat import using_numpy_2
2627
from pytensor.printing import pprint
2728
from pytensor.raise_op import Assert
2829
from pytensor.tensor import blas, blas_c
@@ -391,11 +392,20 @@ def test_maximum_minimum_grad():
391392
grad=_grad_broadcast_unary_normal,
392393
)
393394

395+
396+
# in numpy >= 2.0, negating a uint raises an error
397+
neg_good = _good_broadcast_unary_normal.copy()
398+
if using_numpy_2:
399+
neg_bad = {"uint8": neg_good.pop("uint8"), "uint16": neg_good.pop("uint16")}
400+
else:
401+
neg_bad = None
402+
394403
TestNegBroadcast = makeBroadcastTester(
395404
op=neg,
396405
expected=lambda x: -x,
397-
good=_good_broadcast_unary_normal,
406+
good=neg_good,
398407
grad=_grad_broadcast_unary_normal,
408+
bad_compile=neg_bad,
399409
)
400410

401411
TestSgnBroadcast = makeBroadcastTester(

tests/tensor/utils.py

+21
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,7 @@ def makeTester(
339339
good=None,
340340
bad_build=None,
341341
bad_runtime=None,
342+
bad_compile=None,
342343
grad=None,
343344
mode=None,
344345
grad_rtol=None,
@@ -373,6 +374,7 @@ def makeTester(
373374
_test_memmap = test_memmap
374375
_check_name = check_name
375376
_grad_eps = grad_eps
377+
_bad_compile = bad_compile or {}
376378

377379
class Checker:
378380
op = staticmethod(_op)
@@ -382,6 +384,7 @@ class Checker:
382384
good = _good
383385
bad_build = _bad_build
384386
bad_runtime = _bad_runtime
387+
bad_compile = _bad_compile
385388
grad = _grad
386389
mode = _mode
387390
skip = skip_
@@ -539,6 +542,24 @@ def test_bad_build(self):
539542
# instantiated on the following bad inputs: %s"
540543
# % (self.op, testname, node, inputs))
541544

545+
@config.change_flags(compute_test_value="off")
546+
@pytest.mark.skipif(skip, reason="Skipped")
547+
def test_bad_compile(self):
548+
for testname, inputs in self.bad_compile.items():
549+
inputrs = [shared(input) for input in inputs]
550+
try:
551+
node = safe_make_node(self.op, *inputrs)
552+
except Exception as exc:
553+
err_msg = (
554+
f"Test {self.op}::{testname}: Error occurred while trying"
555+
f" to make a node with inputs {inputs}"
556+
)
557+
exc.args += (err_msg,)
558+
raise
559+
560+
with pytest.raises(Exception):
561+
inplace_func([], node.outputs, mode=mode, name="test_bad_runtime")
562+
542563
@config.change_flags(compute_test_value="off")
543564
@pytest.mark.skipif(skip, reason="Skipped")
544565
def test_bad_runtime(self):

0 commit comments

Comments
 (0)