Skip to content

Commit ab11e20

Browse files
authored
Copy _Py_CAST() from Python main branch (#40)
* Fix C++ compatibility. * Add more tests on C++.
1 parent 2df7edd commit ab11e20

File tree

3 files changed

+136
-52
lines changed

3 files changed

+136
-52
lines changed

pythoncapi_compat.h

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,30 @@ extern "C" {
3535
// C++ compatibility: _Py_CAST() and _Py_NULL
3636
#ifndef _Py_CAST
3737
# ifdef __cplusplus
38-
# define _Py_CAST(type, expr) \
39-
const_cast<type>(reinterpret_cast<const type>(expr))
38+
extern "C++" {
39+
namespace {
40+
template <typename type, typename expr_type>
41+
inline type _Py_CAST_impl(expr_type *expr) {
42+
return reinterpret_cast<type>(expr);
43+
}
44+
45+
template <typename type, typename expr_type>
46+
inline type _Py_CAST_impl(expr_type const *expr) {
47+
return reinterpret_cast<type>(const_cast<expr_type *>(expr));
48+
}
49+
50+
template <typename type, typename expr_type>
51+
inline type _Py_CAST_impl(expr_type &expr) {
52+
return static_cast<type>(expr);
53+
}
54+
55+
template <typename type, typename expr_type>
56+
inline type _Py_CAST_impl(expr_type const &expr) {
57+
return static_cast<type>(const_cast<expr_type &>(expr));
58+
}
59+
}
60+
}
61+
# define _Py_CAST(type, expr) _Py_CAST_impl<type>(expr)
4062
# else
4163
# define _Py_CAST(type, expr) ((type)(expr))
4264
# endif

tests/test_pythoncapi_compat.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,9 @@ def main():
150150
VERBOSE = ("-v" in sys.argv[1:] or "--verbose" in sys.argv[1:])
151151

152152
# Implementing PyFrame_GetLocals() and PyCode_GetCode() require the
153-
# internal C API in Python 3.11 alpha versions
154-
if 0x30b0000 <= sys.hexversion < 0x30b00b1:
153+
# internal C API in Python 3.11 alpha versions. Skip also Python 3.11b1
154+
# which has issues with C++ casts: _Py_CAST() macro.
155+
if 0x30b0000 <= sys.hexversion <= 0x30b00b1:
155156
version = sys.version.split()[0]
156157
print("SKIP TESTS: Python %s is not supported" % version)
157158
return

tests/test_pythoncapi_compat_cext.c

Lines changed: 109 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ static PyObject *
3636
test_object(PyObject *Py_UNUSED(module), PyObject* Py_UNUSED(ignored))
3737
{
3838
PyObject *obj = PyList_New(0);
39-
if (obj == NULL) {
40-
return NULL;
39+
if (obj == _Py_NULL) {
40+
return _Py_NULL;
4141
}
4242
Py_ssize_t refcnt = Py_REFCNT(obj);
4343

@@ -53,7 +53,7 @@ test_object(PyObject *Py_UNUSED(module), PyObject* Py_UNUSED(ignored))
5353
assert(Py_REFCNT(obj) == (refcnt + 1));
5454
Py_DECREF(xref);
5555

56-
assert(Py_XNewRef(NULL) == NULL);
56+
assert(Py_XNewRef(_Py_NULL) == _Py_NULL);
5757

5858
// Py_SETREF()
5959
PyObject *setref = Py_NewRef(obj);
@@ -65,20 +65,20 @@ test_object(PyObject *Py_UNUSED(module), PyObject* Py_UNUSED(ignored))
6565
assert(Py_REFCNT(obj) == refcnt);
6666
Py_INCREF(setref);
6767

68-
Py_SETREF(setref, NULL);
69-
assert(setref == NULL);
68+
Py_SETREF(setref, _Py_NULL);
69+
assert(setref == _Py_NULL);
7070

7171
// Py_XSETREF()
72-
PyObject *xsetref = NULL;
72+
PyObject *xsetref = _Py_NULL;
7373

7474
Py_INCREF(obj);
7575
assert(Py_REFCNT(obj) == (refcnt + 1));
7676
Py_XSETREF(xsetref, obj);
7777
assert(xsetref == obj);
7878

79-
Py_XSETREF(xsetref, NULL);
79+
Py_XSETREF(xsetref, _Py_NULL);
8080
assert(Py_REFCNT(obj) == refcnt);
81-
assert(xsetref == NULL);
81+
assert(xsetref == _Py_NULL);
8282

8383
// Py_SET_REFCNT
8484
Py_SET_REFCNT(obj, Py_REFCNT(obj));
@@ -103,8 +103,8 @@ test_py_is(PyObject *Py_UNUSED(module), PyObject* Py_UNUSED(ignored))
103103
PyObject *o_true = Py_True;
104104
PyObject *o_false = Py_False;
105105
PyObject *obj = PyList_New(0);
106-
if (obj == NULL) {
107-
return NULL;
106+
if (obj == _Py_NULL) {
107+
return _Py_NULL;
108108
}
109109

110110
/* test Py_Is() */
@@ -138,9 +138,9 @@ test_frame(PyObject *Py_UNUSED(module), PyObject* Py_UNUSED(ignored))
138138

139139
// test PyThreadState_GetFrame()
140140
PyFrameObject *frame = PyThreadState_GetFrame(tstate);
141-
if (frame == NULL) {
141+
if (frame == _Py_NULL) {
142142
PyErr_SetString(PyExc_AssertionError, "PyThreadState_GetFrame failed");
143-
return NULL;
143+
return _Py_NULL;
144144
}
145145

146146
// test _PyThreadState_GetFrameBorrow()
@@ -151,7 +151,7 @@ test_frame(PyObject *Py_UNUSED(module), PyObject* Py_UNUSED(ignored))
151151

152152
// test PyFrame_GetCode()
153153
PyCodeObject *code = PyFrame_GetCode(frame);
154-
assert(code != NULL);
154+
assert(code != _Py_NULL);
155155
assert(PyCode_Check(code));
156156

157157
// test _PyFrame_GetCodeBorrow()
@@ -163,12 +163,12 @@ test_frame(PyObject *Py_UNUSED(module), PyObject* Py_UNUSED(ignored))
163163

164164
// PyFrame_GetBack()
165165
PyFrameObject* back = PyFrame_GetBack(frame);
166-
if (back != NULL) {
166+
if (back != _Py_NULL) {
167167
assert(PyFrame_Check(back));
168168
}
169169

170170
// test _PyFrame_GetBackBorrow()
171-
if (back != NULL) {
171+
if (back != _Py_NULL) {
172172
Py_ssize_t back_refcnt = Py_REFCNT(back);
173173
PyFrameObject *back2 = _PyFrame_GetBackBorrow(frame);
174174
assert(back2 == back);
@@ -182,17 +182,17 @@ test_frame(PyObject *Py_UNUSED(module), PyObject* Py_UNUSED(ignored))
182182

183183
// test PyFrame_GetLocals()
184184
PyObject *locals = PyFrame_GetLocals(frame);
185-
assert(locals != NULL);
185+
assert(locals != _Py_NULL);
186186
assert(PyDict_Check(locals));
187187

188188
// test PyFrame_GetGlobals()
189189
PyObject *globals = PyFrame_GetGlobals(frame);
190-
assert(globals != NULL);
190+
assert(globals != _Py_NULL);
191191
assert(PyDict_Check(globals));
192192

193193
// test PyFrame_GetBuiltins()
194194
PyObject *builtins = PyFrame_GetBuiltins(frame);
195-
assert(builtins != NULL);
195+
assert(builtins != _Py_NULL);
196196
assert(PyDict_Check(builtins));
197197

198198
assert(locals != globals);
@@ -220,12 +220,12 @@ test_thread_state(PyObject *Py_UNUSED(module), PyObject* Py_UNUSED(ignored))
220220

221221
// test PyThreadState_GetInterpreter()
222222
PyInterpreterState *interp = PyThreadState_GetInterpreter(tstate);
223-
assert(interp != NULL);
223+
assert(interp != _Py_NULL);
224224

225225
#if !defined(PYPY_VERSION)
226226
// test PyThreadState_GetFrame()
227227
PyFrameObject *frame = PyThreadState_GetFrame(tstate);
228-
if (frame != NULL) {
228+
if (frame != _Py_NULL) {
229229
assert(PyFrame_Check(frame));
230230
}
231231
Py_XDECREF(frame);
@@ -251,7 +251,7 @@ test_interpreter(PyObject *Py_UNUSED(module), PyObject* Py_UNUSED(ignored))
251251
{
252252
// test PyInterpreterState_Get()
253253
PyInterpreterState *interp = PyInterpreterState_Get();
254-
assert(interp != NULL);
254+
assert(interp != _Py_NULL);
255255
PyThreadState *tstate = PyThreadState_Get();
256256
PyInterpreterState *interp2 = PyThreadState_GetInterpreter(tstate);
257257
assert(interp == interp2);
@@ -267,21 +267,21 @@ test_calls(PyObject *Py_UNUSED(module), PyObject* Py_UNUSED(ignored))
267267

268268
// test PyObject_CallNoArgs(): str() returns ''
269269
PyObject *res = PyObject_CallNoArgs(func);
270-
if (res == NULL) {
271-
return NULL;
270+
if (res == _Py_NULL) {
271+
return _Py_NULL;
272272
}
273273
assert(PyUnicode_Check(res));
274274
Py_DECREF(res);
275275

276276
// test PyObject_CallOneArg(): str(1) returns '1'
277277
PyObject *arg = PyLong_FromLong(1);
278-
if (arg == NULL) {
279-
return NULL;
278+
if (arg == _Py_NULL) {
279+
return _Py_NULL;
280280
}
281281
res = PyObject_CallOneArg(func, arg);
282282
Py_DECREF(arg);
283-
if (res == NULL) {
284-
return NULL;
283+
if (res == _Py_NULL) {
284+
return _Py_NULL;
285285
}
286286
assert(PyUnicode_Check(res));
287287
Py_DECREF(res);
@@ -334,7 +334,7 @@ test_module_add_type(PyObject *module)
334334
ASSERT_REFCNT(Py_REFCNT(type) == refcnt + 1);
335335

336336
PyObject *attr = PyObject_GetAttrString(module, type_name);
337-
if (attr == NULL) {
337+
if (attr == _Py_NULL) {
338338
return -1;
339339
}
340340
assert(attr == (PyObject *)type);
@@ -370,7 +370,7 @@ test_module_addobjectref(PyObject *module)
370370
ASSERT_REFCNT(Py_REFCNT(obj) == refcnt);
371371

372372
// PyModule_AddObjectRef() with value=NULL must not crash
373-
int res = PyModule_AddObjectRef(module, name, NULL);
373+
int res = PyModule_AddObjectRef(module, name, _Py_NULL);
374374
assert(res < 0);
375375
PyErr_Clear();
376376

@@ -382,8 +382,8 @@ static PyObject *
382382
test_module(PyObject *Py_UNUSED(module), PyObject* Py_UNUSED(ignored))
383383
{
384384
PyObject *module = PyImport_ImportModule("sys");
385-
if (module == NULL) {
386-
return NULL;
385+
if (module == _Py_NULL) {
386+
return _Py_NULL;
387387
}
388388
assert(PyModule_Check(module));
389389

@@ -400,7 +400,7 @@ test_module(PyObject *Py_UNUSED(module), PyObject* Py_UNUSED(ignored))
400400

401401
error:
402402
Py_DECREF(module);
403-
return NULL;
403+
return _Py_NULL;
404404
}
405405

406406

@@ -466,14 +466,14 @@ test_code(PyObject *Py_UNUSED(module), PyObject* Py_UNUSED(ignored))
466466
{
467467
PyThreadState *tstate = PyThreadState_Get();
468468
PyFrameObject *frame = PyThreadState_GetFrame(tstate);
469-
if (frame == NULL) {
469+
if (frame == _Py_NULL) {
470470
PyErr_SetString(PyExc_AssertionError, "PyThreadState_GetFrame failed");
471-
return NULL;
471+
return _Py_NULL;
472472
}
473473
PyCodeObject *code = PyFrame_GetCode(frame);
474474

475475
PyObject *co_code = PyCode_GetCode(code);
476-
assert(co_code != NULL);
476+
assert(co_code != _Py_NULL);
477477
assert(PyBytes_Check(co_code));
478478
Py_DECREF(co_code);
479479

@@ -484,24 +484,85 @@ test_code(PyObject *Py_UNUSED(module), PyObject* Py_UNUSED(ignored))
484484
#endif
485485

486486

487+
#ifdef __cplusplus
488+
// Class to test operator casting an object to PyObject*
489+
class StrongRef
490+
{
491+
public:
492+
StrongRef(PyObject *obj) : m_obj(obj) {
493+
Py_INCREF(this->m_obj);
494+
}
495+
496+
~StrongRef() {
497+
Py_DECREF(this->m_obj);
498+
}
499+
500+
// Cast to PyObject*: get a borrowed reference
501+
inline operator PyObject*() const { return this->m_obj; }
502+
503+
private:
504+
PyObject *m_obj; // Strong reference
505+
};
506+
#endif
507+
508+
509+
static PyObject *
510+
test_api_casts(PyObject *Py_UNUSED(module), PyObject *Py_UNUSED(args))
511+
{
512+
PyObject *obj = Py_BuildValue("(ii)", 1, 2);
513+
if (obj == _Py_NULL) {
514+
return _Py_NULL;
515+
}
516+
Py_ssize_t refcnt = Py_REFCNT(obj);
517+
assert(refcnt >= 1);
518+
519+
// gh-92138: For backward compatibility, functions of Python C API accepts
520+
// "const PyObject*". Check that using it does not emit C++ compiler
521+
// warnings.
522+
const PyObject *const_obj = obj;
523+
Py_INCREF(const_obj);
524+
Py_DECREF(const_obj);
525+
PyTypeObject *type = Py_TYPE(const_obj);
526+
assert(Py_REFCNT(const_obj) == refcnt);
527+
assert(type == &PyTuple_Type);
528+
assert(PyTuple_GET_SIZE(const_obj) == 2);
529+
PyObject *one = PyTuple_GET_ITEM(const_obj, 0);
530+
assert(PyLong_AsLong(one) == 1);
531+
532+
#ifdef __cplusplus
533+
// gh-92898: StrongRef doesn't inherit from PyObject but has an operator to
534+
// cast to PyObject*.
535+
StrongRef strong_ref(obj);
536+
assert(Py_TYPE(strong_ref) == &PyTuple_Type);
537+
assert(Py_REFCNT(strong_ref) == (refcnt + 1));
538+
Py_INCREF(strong_ref);
539+
Py_DECREF(strong_ref);
540+
#endif
541+
542+
Py_DECREF(obj);
543+
Py_RETURN_NONE;
544+
}
545+
546+
487547
static struct PyMethodDef methods[] = {
488-
{"test_object", test_object, METH_NOARGS, NULL},
489-
{"test_py_is", test_py_is, METH_NOARGS, NULL},
548+
{"test_object", test_object, METH_NOARGS, _Py_NULL},
549+
{"test_py_is", test_py_is, METH_NOARGS, _Py_NULL},
490550
#if !defined(PYPY_VERSION)
491-
{"test_frame", test_frame, METH_NOARGS, NULL},
551+
{"test_frame", test_frame, METH_NOARGS, _Py_NULL},
492552
#endif
493-
{"test_thread_state", test_thread_state, METH_NOARGS, NULL},
494-
{"test_interpreter", test_interpreter, METH_NOARGS, NULL},
495-
{"test_calls", test_calls, METH_NOARGS, NULL},
496-
{"test_gc", test_gc, METH_NOARGS, NULL},
497-
{"test_module", test_module, METH_NOARGS, NULL},
553+
{"test_thread_state", test_thread_state, METH_NOARGS, _Py_NULL},
554+
{"test_interpreter", test_interpreter, METH_NOARGS, _Py_NULL},
555+
{"test_calls", test_calls, METH_NOARGS, _Py_NULL},
556+
{"test_gc", test_gc, METH_NOARGS, _Py_NULL},
557+
{"test_module", test_module, METH_NOARGS, _Py_NULL},
498558
#if (PY_VERSION_HEX <= 0x030B00A1 || 0x030B00A7 <= PY_VERSION_HEX) && !defined(PYPY_VERSION)
499-
{"test_float_pack", test_float_pack, METH_NOARGS, NULL},
559+
{"test_float_pack", test_float_pack, METH_NOARGS, _Py_NULL},
500560
#endif
501561
#if !defined(PYPY_VERSION)
502-
{"test_code", test_code, METH_NOARGS, NULL},
562+
{"test_code", test_code, METH_NOARGS, _Py_NULL},
503563
#endif
504-
{NULL, NULL, 0, NULL}
564+
{"test_api_casts", test_api_casts, METH_NOARGS, _Py_NULL},
565+
{_Py_NULL, _Py_NULL, 0, _Py_NULL}
505566
};
506567

507568

@@ -552,8 +613,8 @@ INIT_FUNC(void)
552613
{
553614
Py_InitModule4(MODULE_NAME_STR,
554615
methods,
555-
NULL,
556-
NULL,
616+
_Py_NULL,
617+
_Py_NULL,
557618
PYTHON_API_VERSION);
558619
}
559620
#endif

0 commit comments

Comments
 (0)