Skip to content

Commit 8267bf9

Browse files
authored
Merge pull request #101 from SwayamInSync/100-fix
Fix #100: Accept int/float subclasses in QuadPrecision comparisons
2 parents 3164b96 + 0f628ed commit 8267bf9

2 files changed

Lines changed: 28 additions & 1 deletion

File tree

src/csrc/scalar_ops.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ quad_richcompare(QuadPrecisionObject *self, PyObject *other, int cmp_op)
144144
return NULL;
145145
}
146146
}
147-
else if (PyLong_CheckExact(other) || PyFloat_CheckExact(other)) {
147+
else if (PyLong_Check(other) || PyFloat_Check(other)) {
148148
other_quad = QuadPrecision_from_object(other, backend);
149149
if (other_quad == NULL) {
150150
return NULL;

tests/test_quaddtype.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1366,6 +1366,33 @@ def test_comparisons(op, a, b):
13661366
assert op_func(quad_a, quad_b) == op_func(float_a, float_b)
13671367

13681368

1369+
@pytest.mark.parametrize("op", ["eq", "ne", "le", "lt", "ge", "gt"])
1370+
@pytest.mark.parametrize(
1371+
"quad_val, other",
1372+
[
1373+
# bool is a subclass of int — exercises PyLong_Check (regression: gh-100)
1374+
(1, True),
1375+
(0, False),
1376+
(1, False),
1377+
(0, True),
1378+
(2, True),
1379+
# np.float64 is a subclass of float — exercises PyFloat_Check
1380+
(1, np.float64(1.0)),
1381+
(1, np.float64(2.0)),
1382+
(0, np.float64(-0.0)),
1383+
],
1384+
)
1385+
def test_comparisons_with_python_subclasses(op, quad_val, other):
1386+
op_func = getattr(operator, op)
1387+
quad_a = QuadPrecision(quad_val)
1388+
expected = op_func(float(quad_val), float(other))
1389+
1390+
# Forward: QuadPrecision OP subclass-instance
1391+
assert op_func(quad_a, other) == expected, f"Failed {op} between QuadPrecision({quad_val}) and {other} (type {type(other)})"
1392+
# Reverse: subclass-instance OP QuadPrecision
1393+
assert op_func(other, quad_a) == op_func(float(other), float(quad_val)), f"Failed {op} between {other} (type {type(other)}) and QuadPrecision({quad_val})"
1394+
1395+
13691396
@pytest.mark.parametrize("op", ["eq", "ne", "le", "lt", "ge", "gt"])
13701397
@pytest.mark.parametrize("a", ["3.0", "12.5", "100.0", "0.0", "-0.0", "inf", "-inf", "nan", "-nan"])
13711398
@pytest.mark.parametrize("b", ["3.0", "12.5", "100.0", "0.0", "-0.0", "inf", "-inf", "nan", "-nan"])

0 commit comments

Comments
 (0)