Skip to content

Commit 9f69a58

Browse files
gh-133767: Fix use-after-free in the unicode-escape decoder with an error handler (GH-129648)
If the error handler is used, a new bytes object is created to set as the object attribute of UnicodeDecodeError, and that bytes object then replaces the original data. A pointer to the decoded data will became invalid after destroying that temporary bytes object. So we need other way to return the first invalid escape from _PyUnicode_DecodeUnicodeEscapeInternal(). _PyBytes_DecodeEscape() does not have such issue, because it does not use the error handlers registry, but it should be changed for compatibility with _PyUnicode_DecodeUnicodeEscapeInternal().
1 parent 734e15b commit 9f69a58

File tree

8 files changed

+160
-63
lines changed

8 files changed

+160
-63
lines changed

Include/internal/pycore_bytesobject.h

+3-2
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@ extern PyObject* _PyBytes_FromHex(
2020

2121
// Helper for PyBytes_DecodeEscape that detects invalid escape chars.
2222
// Export for test_peg_generator.
23-
PyAPI_FUNC(PyObject*) _PyBytes_DecodeEscape(const char *, Py_ssize_t,
24-
const char *, const char **);
23+
PyAPI_FUNC(PyObject*) _PyBytes_DecodeEscape2(const char *, Py_ssize_t,
24+
const char *,
25+
int *, const char **);
2526

2627

2728
// Substring Search.

Include/internal/pycore_unicodeobject.h

+8-4
Original file line numberDiff line numberDiff line change
@@ -139,14 +139,18 @@ extern PyObject* _PyUnicode_DecodeUnicodeEscapeStateful(
139139
// Helper for PyUnicode_DecodeUnicodeEscape that detects invalid escape
140140
// chars.
141141
// Export for test_peg_generator.
142-
PyAPI_FUNC(PyObject*) _PyUnicode_DecodeUnicodeEscapeInternal(
142+
PyAPI_FUNC(PyObject*) _PyUnicode_DecodeUnicodeEscapeInternal2(
143143
const char *string, /* Unicode-Escape encoded string */
144144
Py_ssize_t length, /* size of string */
145145
const char *errors, /* error handling */
146146
Py_ssize_t *consumed, /* bytes consumed */
147-
const char **first_invalid_escape); /* on return, points to first
148-
invalid escaped char in
149-
string. */
147+
int *first_invalid_escape_char, /* on return, if not -1, contain the first
148+
invalid escaped char (<= 0xff) or invalid
149+
octal escape (> 0xff) in string. */
150+
const char **first_invalid_escape_ptr); /* on return, if not NULL, may
151+
point to the first invalid escaped
152+
char in string.
153+
May be NULL if errors is not NULL. */
150154

151155
/* --- Raw-Unicode-Escape Codecs ---------------------------------------------- */
152156

Lib/test/test_codeccallbacks.py

+38-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import codecs
33
import html.entities
44
import itertools
5+
import re
56
import sys
67
import unicodedata
78
import unittest
@@ -1125,7 +1126,7 @@ def test_bug828737(self):
11251126
text = 'abc<def>ghi'*n
11261127
text.translate(charmap)
11271128

1128-
def test_mutatingdecodehandler(self):
1129+
def test_mutating_decode_handler(self):
11291130
baddata = [
11301131
("ascii", b"\xff"),
11311132
("utf-7", b"++"),
@@ -1160,6 +1161,42 @@ def mutating(exc):
11601161
for (encoding, data) in baddata:
11611162
self.assertEqual(data.decode(encoding, "test.mutating"), "\u4242")
11621163

1164+
def test_mutating_decode_handler_unicode_escape(self):
1165+
decode = codecs.unicode_escape_decode
1166+
def mutating(exc):
1167+
if isinstance(exc, UnicodeDecodeError):
1168+
r = data.get(exc.object[:exc.end])
1169+
if r is not None:
1170+
exc.object = r[0] + exc.object[exc.end:]
1171+
return ('\u0404', r[1])
1172+
raise AssertionError("don't know how to handle %r" % exc)
1173+
1174+
codecs.register_error('test.mutating2', mutating)
1175+
data = {
1176+
br'\x0': (b'\\', 0),
1177+
br'\x3': (b'xxx\\', 3),
1178+
br'\x5': (b'x\\', 1),
1179+
}
1180+
def check(input, expected, msg):
1181+
with self.assertWarns(DeprecationWarning) as cm:
1182+
self.assertEqual(decode(input, 'test.mutating2'), (expected, len(input)))
1183+
self.assertIn(msg, str(cm.warning))
1184+
1185+
check(br'\x0n\z', '\u0404\n\\z', r'"\z" is an invalid escape sequence')
1186+
check(br'\x0n\501', '\u0404\n\u0141', r'"\501" is an invalid octal escape sequence')
1187+
check(br'\x0z', '\u0404\\z', r'"\z" is an invalid escape sequence')
1188+
1189+
check(br'\x3n\zr', '\u0404\n\\zr', r'"\z" is an invalid escape sequence')
1190+
check(br'\x3zr', '\u0404\\zr', r'"\z" is an invalid escape sequence')
1191+
check(br'\x3z5', '\u0404\\z5', r'"\z" is an invalid escape sequence')
1192+
check(memoryview(br'\x3z5x')[:-1], '\u0404\\z5', r'"\z" is an invalid escape sequence')
1193+
check(memoryview(br'\x3z5xy')[:-2], '\u0404\\z5', r'"\z" is an invalid escape sequence')
1194+
1195+
check(br'\x5n\z', '\u0404\n\\z', r'"\z" is an invalid escape sequence')
1196+
check(br'\x5n\501', '\u0404\n\u0141', r'"\501" is an invalid octal escape sequence')
1197+
check(br'\x5z', '\u0404\\z', r'"\z" is an invalid escape sequence')
1198+
check(memoryview(br'\x5zy')[:-1], '\u0404\\z', r'"\z" is an invalid escape sequence')
1199+
11631200
# issue32583
11641201
def test_crashing_decode_handler(self):
11651202
# better generating one more character to fill the extra space slot

Lib/test/test_codecs.py

+42-10
Original file line numberDiff line numberDiff line change
@@ -1196,23 +1196,39 @@ def test_escape(self):
11961196
check(br"[\1010]", b"[A0]")
11971197
check(br"[\x41]", b"[A]")
11981198
check(br"[\x410]", b"[A0]")
1199+
1200+
def test_warnings(self):
1201+
decode = codecs.escape_decode
1202+
check = coding_checker(self, decode)
11991203
for i in range(97, 123):
12001204
b = bytes([i])
12011205
if b not in b'abfnrtvx':
1202-
with self.assertWarns(DeprecationWarning):
1206+
with self.assertWarnsRegex(DeprecationWarning,
1207+
r'"\\%c" is an invalid escape sequence' % i):
12031208
check(b"\\" + b, b"\\" + b)
1204-
with self.assertWarns(DeprecationWarning):
1209+
with self.assertWarnsRegex(DeprecationWarning,
1210+
r'"\\%c" is an invalid escape sequence' % (i-32)):
12051211
check(b"\\" + b.upper(), b"\\" + b.upper())
1206-
with self.assertWarns(DeprecationWarning):
1212+
with self.assertWarnsRegex(DeprecationWarning,
1213+
r'"\\8" is an invalid escape sequence'):
12071214
check(br"\8", b"\\8")
12081215
with self.assertWarns(DeprecationWarning):
12091216
check(br"\9", b"\\9")
1210-
with self.assertWarns(DeprecationWarning):
1217+
with self.assertWarnsRegex(DeprecationWarning,
1218+
r'"\\\xfa" is an invalid escape sequence') as cm:
12111219
check(b"\\\xfa", b"\\\xfa")
12121220
for i in range(0o400, 0o1000):
1213-
with self.assertWarns(DeprecationWarning):
1221+
with self.assertWarnsRegex(DeprecationWarning,
1222+
r'"\\%o" is an invalid octal escape sequence' % i):
12141223
check(rb'\%o' % i, bytes([i & 0o377]))
12151224

1225+
with self.assertWarnsRegex(DeprecationWarning,
1226+
r'"\\z" is an invalid escape sequence'):
1227+
self.assertEqual(decode(br'\x\z', 'ignore'), (b'\\z', 4))
1228+
with self.assertWarnsRegex(DeprecationWarning,
1229+
r'"\\501" is an invalid octal escape sequence'):
1230+
self.assertEqual(decode(br'\x\501', 'ignore'), (b'A', 6))
1231+
12161232
def test_errors(self):
12171233
decode = codecs.escape_decode
12181234
self.assertRaises(ValueError, decode, br"\x")
@@ -2661,24 +2677,40 @@ def test_escape_decode(self):
26612677
check(br"[\x410]", "[A0]")
26622678
check(br"\u20ac", "\u20ac")
26632679
check(br"\U0001d120", "\U0001d120")
2680+
2681+
def test_decode_warnings(self):
2682+
decode = codecs.unicode_escape_decode
2683+
check = coding_checker(self, decode)
26642684
for i in range(97, 123):
26652685
b = bytes([i])
26662686
if b not in b'abfnrtuvx':
2667-
with self.assertWarns(DeprecationWarning):
2687+
with self.assertWarnsRegex(DeprecationWarning,
2688+
r'"\\%c" is an invalid escape sequence' % i):
26682689
check(b"\\" + b, "\\" + chr(i))
26692690
if b.upper() not in b'UN':
2670-
with self.assertWarns(DeprecationWarning):
2691+
with self.assertWarnsRegex(DeprecationWarning,
2692+
r'"\\%c" is an invalid escape sequence' % (i-32)):
26712693
check(b"\\" + b.upper(), "\\" + chr(i-32))
2672-
with self.assertWarns(DeprecationWarning):
2694+
with self.assertWarnsRegex(DeprecationWarning,
2695+
r'"\\8" is an invalid escape sequence'):
26732696
check(br"\8", "\\8")
26742697
with self.assertWarns(DeprecationWarning):
26752698
check(br"\9", "\\9")
2676-
with self.assertWarns(DeprecationWarning):
2699+
with self.assertWarnsRegex(DeprecationWarning,
2700+
r'"\\\xfa" is an invalid escape sequence') as cm:
26772701
check(b"\\\xfa", "\\\xfa")
26782702
for i in range(0o400, 0o1000):
2679-
with self.assertWarns(DeprecationWarning):
2703+
with self.assertWarnsRegex(DeprecationWarning,
2704+
r'"\\%o" is an invalid octal escape sequence' % i):
26802705
check(rb'\%o' % i, chr(i))
26812706

2707+
with self.assertWarnsRegex(DeprecationWarning,
2708+
r'"\\z" is an invalid escape sequence'):
2709+
self.assertEqual(decode(br'\x\z', 'ignore'), ('\\z', 4))
2710+
with self.assertWarnsRegex(DeprecationWarning,
2711+
r'"\\501" is an invalid octal escape sequence'):
2712+
self.assertEqual(decode(br'\x\501', 'ignore'), ('\u0141', 6))
2713+
26822714
def test_decode_errors(self):
26832715
decode = codecs.unicode_escape_decode
26842716
for c, d in (b'x', 2), (b'u', 4), (b'U', 4):
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Fix use-after-free in the "unicode-escape" decoder with a non-"strict" error
2+
handler.

Objects/bytesobject.c

+23-18
Original file line numberDiff line numberDiff line change
@@ -1075,10 +1075,11 @@ _PyBytes_FormatEx(const char *format, Py_ssize_t format_len,
10751075
}
10761076

10771077
/* Unescape a backslash-escaped string. */
1078-
PyObject *_PyBytes_DecodeEscape(const char *s,
1078+
PyObject *_PyBytes_DecodeEscape2(const char *s,
10791079
Py_ssize_t len,
10801080
const char *errors,
1081-
const char **first_invalid_escape)
1081+
int *first_invalid_escape_char,
1082+
const char **first_invalid_escape_ptr)
10821083
{
10831084
int c;
10841085
char *p;
@@ -1092,7 +1093,8 @@ PyObject *_PyBytes_DecodeEscape(const char *s,
10921093
return NULL;
10931094
writer.overallocate = 1;
10941095

1095-
*first_invalid_escape = NULL;
1096+
*first_invalid_escape_char = -1;
1097+
*first_invalid_escape_ptr = NULL;
10961098

10971099
end = s + len;
10981100
while (s < end) {
@@ -1130,9 +1132,10 @@ PyObject *_PyBytes_DecodeEscape(const char *s,
11301132
c = (c<<3) + *s++ - '0';
11311133
}
11321134
if (c > 0377) {
1133-
if (*first_invalid_escape == NULL) {
1134-
*first_invalid_escape = s-3; /* Back up 3 chars, since we've
1135-
already incremented s. */
1135+
if (*first_invalid_escape_char == -1) {
1136+
*first_invalid_escape_char = c;
1137+
/* Back up 3 chars, since we've already incremented s. */
1138+
*first_invalid_escape_ptr = s - 3;
11361139
}
11371140
}
11381141
*p++ = c;
@@ -1173,9 +1176,10 @@ PyObject *_PyBytes_DecodeEscape(const char *s,
11731176
break;
11741177

11751178
default:
1176-
if (*first_invalid_escape == NULL) {
1177-
*first_invalid_escape = s-1; /* Back up one char, since we've
1178-
already incremented s. */
1179+
if (*first_invalid_escape_char == -1) {
1180+
*first_invalid_escape_char = (unsigned char)s[-1];
1181+
/* Back up one char, since we've already incremented s. */
1182+
*first_invalid_escape_ptr = s - 1;
11791183
}
11801184
*p++ = '\\';
11811185
s--;
@@ -1195,18 +1199,19 @@ PyObject *PyBytes_DecodeEscape(const char *s,
11951199
Py_ssize_t Py_UNUSED(unicode),
11961200
const char *Py_UNUSED(recode_encoding))
11971201
{
1198-
const char* first_invalid_escape;
1199-
PyObject *result = _PyBytes_DecodeEscape(s, len, errors,
1200-
&first_invalid_escape);
1202+
int first_invalid_escape_char;
1203+
const char *first_invalid_escape_ptr;
1204+
PyObject *result = _PyBytes_DecodeEscape2(s, len, errors,
1205+
&first_invalid_escape_char,
1206+
&first_invalid_escape_ptr);
12011207
if (result == NULL)
12021208
return NULL;
1203-
if (first_invalid_escape != NULL) {
1204-
unsigned char c = *first_invalid_escape;
1205-
if ('4' <= c && c <= '7') {
1209+
if (first_invalid_escape_char != -1) {
1210+
if (first_invalid_escape_char > 0xff) {
12061211
if (PyErr_WarnFormat(PyExc_DeprecationWarning, 1,
1207-
"b\"\\%.3s\" is an invalid octal escape sequence. "
1212+
"b\"\\%o\" is an invalid octal escape sequence. "
12081213
"Such sequences will not work in the future. ",
1209-
first_invalid_escape) < 0)
1214+
first_invalid_escape_char) < 0)
12101215
{
12111216
Py_DECREF(result);
12121217
return NULL;
@@ -1216,7 +1221,7 @@ PyObject *PyBytes_DecodeEscape(const char *s,
12161221
if (PyErr_WarnFormat(PyExc_DeprecationWarning, 1,
12171222
"b\"\\%c\" is an invalid escape sequence. "
12181223
"Such sequences will not work in the future. ",
1219-
c) < 0)
1224+
first_invalid_escape_char) < 0)
12201225
{
12211226
Py_DECREF(result);
12221227
return NULL;

Objects/unicodeobject.c

+28-18
Original file line numberDiff line numberDiff line change
@@ -6596,21 +6596,24 @@ _PyUnicode_GetNameCAPI(void)
65966596
/* --- Unicode Escape Codec ----------------------------------------------- */
65976597

65986598
PyObject *
6599-
_PyUnicode_DecodeUnicodeEscapeInternal(const char *s,
6599+
_PyUnicode_DecodeUnicodeEscapeInternal2(const char *s,
66006600
Py_ssize_t size,
66016601
const char *errors,
66026602
Py_ssize_t *consumed,
6603-
const char **first_invalid_escape)
6603+
int *first_invalid_escape_char,
6604+
const char **first_invalid_escape_ptr)
66046605
{
66056606
const char *starts = s;
6607+
const char *initial_starts = starts;
66066608
_PyUnicodeWriter writer;
66076609
const char *end;
66086610
PyObject *errorHandler = NULL;
66096611
PyObject *exc = NULL;
66106612
_PyUnicode_Name_CAPI *ucnhash_capi;
66116613

66126614
// so we can remember if we've seen an invalid escape char or not
6613-
*first_invalid_escape = NULL;
6615+
*first_invalid_escape_char = -1;
6616+
*first_invalid_escape_ptr = NULL;
66146617

66156618
if (size == 0) {
66166619
if (consumed) {
@@ -6698,9 +6701,12 @@ _PyUnicode_DecodeUnicodeEscapeInternal(const char *s,
66986701
}
66996702
}
67006703
if (ch > 0377) {
6701-
if (*first_invalid_escape == NULL) {
6702-
*first_invalid_escape = s-3; /* Back up 3 chars, since we've
6703-
already incremented s. */
6704+
if (*first_invalid_escape_char == -1) {
6705+
*first_invalid_escape_char = ch;
6706+
if (starts == initial_starts) {
6707+
/* Back up 3 chars, since we've already incremented s. */
6708+
*first_invalid_escape_ptr = s - 3;
6709+
}
67046710
}
67056711
}
67066712
WRITE_CHAR(ch);
@@ -6795,9 +6801,12 @@ _PyUnicode_DecodeUnicodeEscapeInternal(const char *s,
67956801
goto error;
67966802

67976803
default:
6798-
if (*first_invalid_escape == NULL) {
6799-
*first_invalid_escape = s-1; /* Back up one char, since we've
6800-
already incremented s. */
6804+
if (*first_invalid_escape_char == -1) {
6805+
*first_invalid_escape_char = c;
6806+
if (starts == initial_starts) {
6807+
/* Back up one char, since we've already incremented s. */
6808+
*first_invalid_escape_ptr = s - 1;
6809+
}
68016810
}
68026811
WRITE_ASCII_CHAR('\\');
68036812
WRITE_CHAR(c);
@@ -6842,19 +6851,20 @@ _PyUnicode_DecodeUnicodeEscapeStateful(const char *s,
68426851
const char *errors,
68436852
Py_ssize_t *consumed)
68446853
{
6845-
const char *first_invalid_escape;
6846-
PyObject *result = _PyUnicode_DecodeUnicodeEscapeInternal(s, size, errors,
6854+
int first_invalid_escape_char;
6855+
const char *first_invalid_escape_ptr;
6856+
PyObject *result = _PyUnicode_DecodeUnicodeEscapeInternal2(s, size, errors,
68476857
consumed,
6848-
&first_invalid_escape);
6858+
&first_invalid_escape_char,
6859+
&first_invalid_escape_ptr);
68496860
if (result == NULL)
68506861
return NULL;
6851-
if (first_invalid_escape != NULL) {
6852-
unsigned char c = *first_invalid_escape;
6853-
if ('4' <= c && c <= '7') {
6862+
if (first_invalid_escape_char != -1) {
6863+
if (first_invalid_escape_char > 0xff) {
68546864
if (PyErr_WarnFormat(PyExc_DeprecationWarning, 1,
6855-
"\"\\%.3s\" is an invalid octal escape sequence. "
6865+
"\"\\%o\" is an invalid octal escape sequence. "
68566866
"Such sequences will not work in the future. ",
6857-
first_invalid_escape) < 0)
6867+
first_invalid_escape_char) < 0)
68586868
{
68596869
Py_DECREF(result);
68606870
return NULL;
@@ -6864,7 +6874,7 @@ _PyUnicode_DecodeUnicodeEscapeStateful(const char *s,
68646874
if (PyErr_WarnFormat(PyExc_DeprecationWarning, 1,
68656875
"\"\\%c\" is an invalid escape sequence. "
68666876
"Such sequences will not work in the future. ",
6867-
c) < 0)
6877+
first_invalid_escape_char) < 0)
68686878
{
68696879
Py_DECREF(result);
68706880
return NULL;

0 commit comments

Comments
 (0)