gh-133767: Fix use-after-free in the unicode-escape decoder with an error handler (GH-129648)

If the error handler is used, a new bytes object is created to set as
the object attribute of UnicodeDecodeError, and that bytes object then
replaces the original data. A pointer to the decoded data will become invalid
after that temporary bytes object is destroyed. So we need another way to return
the first invalid escape from _PyUnicode_DecodeUnicodeEscapeInternal().

_PyBytes_DecodeEscape() does not have this issue, because it does not
use the error handler registry, but it should be changed for compatibility
with _PyUnicode_DecodeUnicodeEscapeInternal().
This commit is contained in:
Serhiy Storchaka 2025-05-12 20:42:23 +03:00 committed by GitHub
parent 734e15b70d
commit 9f69a58623
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 160 additions and 63 deletions

View file

@ -20,8 +20,9 @@ extern PyObject* _PyBytes_FromHex(
// Helper for PyBytes_DecodeEscape that detects invalid escape chars. // Helper for PyBytes_DecodeEscape that detects invalid escape chars.
// Export for test_peg_generator. // Export for test_peg_generator.
PyAPI_FUNC(PyObject*) _PyBytes_DecodeEscape(const char *, Py_ssize_t, PyAPI_FUNC(PyObject*) _PyBytes_DecodeEscape2(const char *, Py_ssize_t,
const char *, const char **); const char *,
int *, const char **);
// Substring Search. // Substring Search.

View file

@ -139,14 +139,18 @@ extern PyObject* _PyUnicode_DecodeUnicodeEscapeStateful(
// Helper for PyUnicode_DecodeUnicodeEscape that detects invalid escape // Helper for PyUnicode_DecodeUnicodeEscape that detects invalid escape
// chars. // chars.
// Export for test_peg_generator. // Export for test_peg_generator.
PyAPI_FUNC(PyObject*) _PyUnicode_DecodeUnicodeEscapeInternal( PyAPI_FUNC(PyObject*) _PyUnicode_DecodeUnicodeEscapeInternal2(
const char *string, /* Unicode-Escape encoded string */ const char *string, /* Unicode-Escape encoded string */
Py_ssize_t length, /* size of string */ Py_ssize_t length, /* size of string */
const char *errors, /* error handling */ const char *errors, /* error handling */
Py_ssize_t *consumed, /* bytes consumed */ Py_ssize_t *consumed, /* bytes consumed */
const char **first_invalid_escape); /* on return, points to first int *first_invalid_escape_char, /* on return, if not -1, contain the first
invalid escaped char in invalid escaped char (<= 0xff) or invalid
string. */ octal escape (> 0xff) in string. */
const char **first_invalid_escape_ptr); /* on return, if not NULL, may
point to the first invalid escaped
char in string.
May be NULL if errors is not NULL. */
/* --- Raw-Unicode-Escape Codecs ---------------------------------------------- */ /* --- Raw-Unicode-Escape Codecs ---------------------------------------------- */

View file

@ -2,6 +2,7 @@ from _codecs import _unregister_error as _codecs_unregister_error
import codecs import codecs
import html.entities import html.entities
import itertools import itertools
import re
import sys import sys
import unicodedata import unicodedata
import unittest import unittest
@ -1125,7 +1126,7 @@ class CodecCallbackTest(unittest.TestCase):
text = 'abc<def>ghi'*n text = 'abc<def>ghi'*n
text.translate(charmap) text.translate(charmap)
def test_mutatingdecodehandler(self): def test_mutating_decode_handler(self):
baddata = [ baddata = [
("ascii", b"\xff"), ("ascii", b"\xff"),
("utf-7", b"++"), ("utf-7", b"++"),
@ -1160,6 +1161,42 @@ class CodecCallbackTest(unittest.TestCase):
for (encoding, data) in baddata: for (encoding, data) in baddata:
self.assertEqual(data.decode(encoding, "test.mutating"), "\u4242") self.assertEqual(data.decode(encoding, "test.mutating"), "\u4242")
def test_mutating_decode_handler_unicode_escape(self):
decode = codecs.unicode_escape_decode
def mutating(exc):
if isinstance(exc, UnicodeDecodeError):
r = data.get(exc.object[:exc.end])
if r is not None:
exc.object = r[0] + exc.object[exc.end:]
return ('\u0404', r[1])
raise AssertionError("don't know how to handle %r" % exc)
codecs.register_error('test.mutating2', mutating)
data = {
br'\x0': (b'\\', 0),
br'\x3': (b'xxx\\', 3),
br'\x5': (b'x\\', 1),
}
def check(input, expected, msg):
with self.assertWarns(DeprecationWarning) as cm:
self.assertEqual(decode(input, 'test.mutating2'), (expected, len(input)))
self.assertIn(msg, str(cm.warning))
check(br'\x0n\z', '\u0404\n\\z', r'"\z" is an invalid escape sequence')
check(br'\x0n\501', '\u0404\n\u0141', r'"\501" is an invalid octal escape sequence')
check(br'\x0z', '\u0404\\z', r'"\z" is an invalid escape sequence')
check(br'\x3n\zr', '\u0404\n\\zr', r'"\z" is an invalid escape sequence')
check(br'\x3zr', '\u0404\\zr', r'"\z" is an invalid escape sequence')
check(br'\x3z5', '\u0404\\z5', r'"\z" is an invalid escape sequence')
check(memoryview(br'\x3z5x')[:-1], '\u0404\\z5', r'"\z" is an invalid escape sequence')
check(memoryview(br'\x3z5xy')[:-2], '\u0404\\z5', r'"\z" is an invalid escape sequence')
check(br'\x5n\z', '\u0404\n\\z', r'"\z" is an invalid escape sequence')
check(br'\x5n\501', '\u0404\n\u0141', r'"\501" is an invalid octal escape sequence')
check(br'\x5z', '\u0404\\z', r'"\z" is an invalid escape sequence')
check(memoryview(br'\x5zy')[:-1], '\u0404\\z', r'"\z" is an invalid escape sequence')
# issue32583 # issue32583
def test_crashing_decode_handler(self): def test_crashing_decode_handler(self):
# better generating one more character to fill the extra space slot # better generating one more character to fill the extra space slot

View file

@ -1196,23 +1196,39 @@ class EscapeDecodeTest(unittest.TestCase):
check(br"[\1010]", b"[A0]") check(br"[\1010]", b"[A0]")
check(br"[\x41]", b"[A]") check(br"[\x41]", b"[A]")
check(br"[\x410]", b"[A0]") check(br"[\x410]", b"[A0]")
def test_warnings(self):
decode = codecs.escape_decode
check = coding_checker(self, decode)
for i in range(97, 123): for i in range(97, 123):
b = bytes([i]) b = bytes([i])
if b not in b'abfnrtvx': if b not in b'abfnrtvx':
with self.assertWarns(DeprecationWarning): with self.assertWarnsRegex(DeprecationWarning,
r'"\\%c" is an invalid escape sequence' % i):
check(b"\\" + b, b"\\" + b) check(b"\\" + b, b"\\" + b)
with self.assertWarns(DeprecationWarning): with self.assertWarnsRegex(DeprecationWarning,
r'"\\%c" is an invalid escape sequence' % (i-32)):
check(b"\\" + b.upper(), b"\\" + b.upper()) check(b"\\" + b.upper(), b"\\" + b.upper())
with self.assertWarns(DeprecationWarning): with self.assertWarnsRegex(DeprecationWarning,
r'"\\8" is an invalid escape sequence'):
check(br"\8", b"\\8") check(br"\8", b"\\8")
with self.assertWarns(DeprecationWarning): with self.assertWarns(DeprecationWarning):
check(br"\9", b"\\9") check(br"\9", b"\\9")
with self.assertWarns(DeprecationWarning): with self.assertWarnsRegex(DeprecationWarning,
r'"\\\xfa" is an invalid escape sequence') as cm:
check(b"\\\xfa", b"\\\xfa") check(b"\\\xfa", b"\\\xfa")
for i in range(0o400, 0o1000): for i in range(0o400, 0o1000):
with self.assertWarns(DeprecationWarning): with self.assertWarnsRegex(DeprecationWarning,
r'"\\%o" is an invalid octal escape sequence' % i):
check(rb'\%o' % i, bytes([i & 0o377])) check(rb'\%o' % i, bytes([i & 0o377]))
with self.assertWarnsRegex(DeprecationWarning,
r'"\\z" is an invalid escape sequence'):
self.assertEqual(decode(br'\x\z', 'ignore'), (b'\\z', 4))
with self.assertWarnsRegex(DeprecationWarning,
r'"\\501" is an invalid octal escape sequence'):
self.assertEqual(decode(br'\x\501', 'ignore'), (b'A', 6))
def test_errors(self): def test_errors(self):
decode = codecs.escape_decode decode = codecs.escape_decode
self.assertRaises(ValueError, decode, br"\x") self.assertRaises(ValueError, decode, br"\x")
@ -2661,24 +2677,40 @@ class UnicodeEscapeTest(ReadTest, unittest.TestCase):
check(br"[\x410]", "[A0]") check(br"[\x410]", "[A0]")
check(br"\u20ac", "\u20ac") check(br"\u20ac", "\u20ac")
check(br"\U0001d120", "\U0001d120") check(br"\U0001d120", "\U0001d120")
def test_decode_warnings(self):
decode = codecs.unicode_escape_decode
check = coding_checker(self, decode)
for i in range(97, 123): for i in range(97, 123):
b = bytes([i]) b = bytes([i])
if b not in b'abfnrtuvx': if b not in b'abfnrtuvx':
with self.assertWarns(DeprecationWarning): with self.assertWarnsRegex(DeprecationWarning,
r'"\\%c" is an invalid escape sequence' % i):
check(b"\\" + b, "\\" + chr(i)) check(b"\\" + b, "\\" + chr(i))
if b.upper() not in b'UN': if b.upper() not in b'UN':
with self.assertWarns(DeprecationWarning): with self.assertWarnsRegex(DeprecationWarning,
r'"\\%c" is an invalid escape sequence' % (i-32)):
check(b"\\" + b.upper(), "\\" + chr(i-32)) check(b"\\" + b.upper(), "\\" + chr(i-32))
with self.assertWarns(DeprecationWarning): with self.assertWarnsRegex(DeprecationWarning,
r'"\\8" is an invalid escape sequence'):
check(br"\8", "\\8") check(br"\8", "\\8")
with self.assertWarns(DeprecationWarning): with self.assertWarns(DeprecationWarning):
check(br"\9", "\\9") check(br"\9", "\\9")
with self.assertWarns(DeprecationWarning): with self.assertWarnsRegex(DeprecationWarning,
r'"\\\xfa" is an invalid escape sequence') as cm:
check(b"\\\xfa", "\\\xfa") check(b"\\\xfa", "\\\xfa")
for i in range(0o400, 0o1000): for i in range(0o400, 0o1000):
with self.assertWarns(DeprecationWarning): with self.assertWarnsRegex(DeprecationWarning,
r'"\\%o" is an invalid octal escape sequence' % i):
check(rb'\%o' % i, chr(i)) check(rb'\%o' % i, chr(i))
with self.assertWarnsRegex(DeprecationWarning,
r'"\\z" is an invalid escape sequence'):
self.assertEqual(decode(br'\x\z', 'ignore'), ('\\z', 4))
with self.assertWarnsRegex(DeprecationWarning,
r'"\\501" is an invalid octal escape sequence'):
self.assertEqual(decode(br'\x\501', 'ignore'), ('\u0141', 6))
def test_decode_errors(self): def test_decode_errors(self):
decode = codecs.unicode_escape_decode decode = codecs.unicode_escape_decode
for c, d in (b'x', 2), (b'u', 4), (b'U', 4): for c, d in (b'x', 2), (b'u', 4), (b'U', 4):

View file

@ -0,0 +1,2 @@
Fix use-after-free in the "unicode-escape" decoder with a non-"strict" error
handler.

View file

@ -1075,10 +1075,11 @@ _PyBytes_FormatEx(const char *format, Py_ssize_t format_len,
} }
/* Unescape a backslash-escaped string. */ /* Unescape a backslash-escaped string. */
PyObject *_PyBytes_DecodeEscape(const char *s, PyObject *_PyBytes_DecodeEscape2(const char *s,
Py_ssize_t len, Py_ssize_t len,
const char *errors, const char *errors,
const char **first_invalid_escape) int *first_invalid_escape_char,
const char **first_invalid_escape_ptr)
{ {
int c; int c;
char *p; char *p;
@ -1092,7 +1093,8 @@ PyObject *_PyBytes_DecodeEscape(const char *s,
return NULL; return NULL;
writer.overallocate = 1; writer.overallocate = 1;
*first_invalid_escape = NULL; *first_invalid_escape_char = -1;
*first_invalid_escape_ptr = NULL;
end = s + len; end = s + len;
while (s < end) { while (s < end) {
@ -1130,9 +1132,10 @@ PyObject *_PyBytes_DecodeEscape(const char *s,
c = (c<<3) + *s++ - '0'; c = (c<<3) + *s++ - '0';
} }
if (c > 0377) { if (c > 0377) {
if (*first_invalid_escape == NULL) { if (*first_invalid_escape_char == -1) {
*first_invalid_escape = s-3; /* Back up 3 chars, since we've *first_invalid_escape_char = c;
already incremented s. */ /* Back up 3 chars, since we've already incremented s. */
*first_invalid_escape_ptr = s - 3;
} }
} }
*p++ = c; *p++ = c;
@ -1173,9 +1176,10 @@ PyObject *_PyBytes_DecodeEscape(const char *s,
break; break;
default: default:
if (*first_invalid_escape == NULL) { if (*first_invalid_escape_char == -1) {
*first_invalid_escape = s-1; /* Back up one char, since we've *first_invalid_escape_char = (unsigned char)s[-1];
already incremented s. */ /* Back up one char, since we've already incremented s. */
*first_invalid_escape_ptr = s - 1;
} }
*p++ = '\\'; *p++ = '\\';
s--; s--;
@ -1195,18 +1199,19 @@ PyObject *PyBytes_DecodeEscape(const char *s,
Py_ssize_t Py_UNUSED(unicode), Py_ssize_t Py_UNUSED(unicode),
const char *Py_UNUSED(recode_encoding)) const char *Py_UNUSED(recode_encoding))
{ {
const char* first_invalid_escape; int first_invalid_escape_char;
PyObject *result = _PyBytes_DecodeEscape(s, len, errors, const char *first_invalid_escape_ptr;
&first_invalid_escape); PyObject *result = _PyBytes_DecodeEscape2(s, len, errors,
&first_invalid_escape_char,
&first_invalid_escape_ptr);
if (result == NULL) if (result == NULL)
return NULL; return NULL;
if (first_invalid_escape != NULL) { if (first_invalid_escape_char != -1) {
unsigned char c = *first_invalid_escape; if (first_invalid_escape_char > 0xff) {
if ('4' <= c && c <= '7') {
if (PyErr_WarnFormat(PyExc_DeprecationWarning, 1, if (PyErr_WarnFormat(PyExc_DeprecationWarning, 1,
"b\"\\%.3s\" is an invalid octal escape sequence. " "b\"\\%o\" is an invalid octal escape sequence. "
"Such sequences will not work in the future. ", "Such sequences will not work in the future. ",
first_invalid_escape) < 0) first_invalid_escape_char) < 0)
{ {
Py_DECREF(result); Py_DECREF(result);
return NULL; return NULL;
@ -1216,7 +1221,7 @@ PyObject *PyBytes_DecodeEscape(const char *s,
if (PyErr_WarnFormat(PyExc_DeprecationWarning, 1, if (PyErr_WarnFormat(PyExc_DeprecationWarning, 1,
"b\"\\%c\" is an invalid escape sequence. " "b\"\\%c\" is an invalid escape sequence. "
"Such sequences will not work in the future. ", "Such sequences will not work in the future. ",
c) < 0) first_invalid_escape_char) < 0)
{ {
Py_DECREF(result); Py_DECREF(result);
return NULL; return NULL;

View file

@ -6596,13 +6596,15 @@ _PyUnicode_GetNameCAPI(void)
/* --- Unicode Escape Codec ----------------------------------------------- */ /* --- Unicode Escape Codec ----------------------------------------------- */
PyObject * PyObject *
_PyUnicode_DecodeUnicodeEscapeInternal(const char *s, _PyUnicode_DecodeUnicodeEscapeInternal2(const char *s,
Py_ssize_t size, Py_ssize_t size,
const char *errors, const char *errors,
Py_ssize_t *consumed, Py_ssize_t *consumed,
const char **first_invalid_escape) int *first_invalid_escape_char,
const char **first_invalid_escape_ptr)
{ {
const char *starts = s; const char *starts = s;
const char *initial_starts = starts;
_PyUnicodeWriter writer; _PyUnicodeWriter writer;
const char *end; const char *end;
PyObject *errorHandler = NULL; PyObject *errorHandler = NULL;
@ -6610,7 +6612,8 @@ _PyUnicode_DecodeUnicodeEscapeInternal(const char *s,
_PyUnicode_Name_CAPI *ucnhash_capi; _PyUnicode_Name_CAPI *ucnhash_capi;
// so we can remember if we've seen an invalid escape char or not // so we can remember if we've seen an invalid escape char or not
*first_invalid_escape = NULL; *first_invalid_escape_char = -1;
*first_invalid_escape_ptr = NULL;
if (size == 0) { if (size == 0) {
if (consumed) { if (consumed) {
@ -6698,9 +6701,12 @@ _PyUnicode_DecodeUnicodeEscapeInternal(const char *s,
} }
} }
if (ch > 0377) { if (ch > 0377) {
if (*first_invalid_escape == NULL) { if (*first_invalid_escape_char == -1) {
*first_invalid_escape = s-3; /* Back up 3 chars, since we've *first_invalid_escape_char = ch;
already incremented s. */ if (starts == initial_starts) {
/* Back up 3 chars, since we've already incremented s. */
*first_invalid_escape_ptr = s - 3;
}
} }
} }
WRITE_CHAR(ch); WRITE_CHAR(ch);
@ -6795,9 +6801,12 @@ _PyUnicode_DecodeUnicodeEscapeInternal(const char *s,
goto error; goto error;
default: default:
if (*first_invalid_escape == NULL) { if (*first_invalid_escape_char == -1) {
*first_invalid_escape = s-1; /* Back up one char, since we've *first_invalid_escape_char = c;
already incremented s. */ if (starts == initial_starts) {
/* Back up one char, since we've already incremented s. */
*first_invalid_escape_ptr = s - 1;
}
} }
WRITE_ASCII_CHAR('\\'); WRITE_ASCII_CHAR('\\');
WRITE_CHAR(c); WRITE_CHAR(c);
@ -6842,19 +6851,20 @@ _PyUnicode_DecodeUnicodeEscapeStateful(const char *s,
const char *errors, const char *errors,
Py_ssize_t *consumed) Py_ssize_t *consumed)
{ {
const char *first_invalid_escape; int first_invalid_escape_char;
PyObject *result = _PyUnicode_DecodeUnicodeEscapeInternal(s, size, errors, const char *first_invalid_escape_ptr;
PyObject *result = _PyUnicode_DecodeUnicodeEscapeInternal2(s, size, errors,
consumed, consumed,
&first_invalid_escape); &first_invalid_escape_char,
&first_invalid_escape_ptr);
if (result == NULL) if (result == NULL)
return NULL; return NULL;
if (first_invalid_escape != NULL) { if (first_invalid_escape_char != -1) {
unsigned char c = *first_invalid_escape; if (first_invalid_escape_char > 0xff) {
if ('4' <= c && c <= '7') {
if (PyErr_WarnFormat(PyExc_DeprecationWarning, 1, if (PyErr_WarnFormat(PyExc_DeprecationWarning, 1,
"\"\\%.3s\" is an invalid octal escape sequence. " "\"\\%o\" is an invalid octal escape sequence. "
"Such sequences will not work in the future. ", "Such sequences will not work in the future. ",
first_invalid_escape) < 0) first_invalid_escape_char) < 0)
{ {
Py_DECREF(result); Py_DECREF(result);
return NULL; return NULL;
@ -6864,7 +6874,7 @@ _PyUnicode_DecodeUnicodeEscapeStateful(const char *s,
if (PyErr_WarnFormat(PyExc_DeprecationWarning, 1, if (PyErr_WarnFormat(PyExc_DeprecationWarning, 1,
"\"\\%c\" is an invalid escape sequence. " "\"\\%c\" is an invalid escape sequence. "
"Such sequences will not work in the future. ", "Such sequences will not work in the future. ",
c) < 0) first_invalid_escape_char) < 0)
{ {
Py_DECREF(result); Py_DECREF(result);
return NULL; return NULL;

View file

@ -196,15 +196,18 @@ decode_unicode_with_escapes(Parser *parser, const char *s, size_t len, Token *t)
len = (size_t)(p - buf); len = (size_t)(p - buf);
s = buf; s = buf;
const char *first_invalid_escape; int first_invalid_escape_char;
v = _PyUnicode_DecodeUnicodeEscapeInternal(s, (Py_ssize_t)len, NULL, NULL, &first_invalid_escape); const char *first_invalid_escape_ptr;
v = _PyUnicode_DecodeUnicodeEscapeInternal2(s, (Py_ssize_t)len, NULL, NULL,
&first_invalid_escape_char,
&first_invalid_escape_ptr);
// HACK: later we can simply pass the line no, since we don't preserve the tokens // HACK: later we can simply pass the line no, since we don't preserve the tokens
// when we are decoding the string but we preserve the line numbers. // when we are decoding the string but we preserve the line numbers.
if (v != NULL && first_invalid_escape != NULL && t != NULL) { if (v != NULL && first_invalid_escape_ptr != NULL && t != NULL) {
if (warn_invalid_escape_sequence(parser, s, first_invalid_escape, t) < 0) { if (warn_invalid_escape_sequence(parser, s, first_invalid_escape_ptr, t) < 0) {
/* We have not decref u before because first_invalid_escape points /* We have not decref u before because first_invalid_escape_ptr
inside u. */ points inside u. */
Py_XDECREF(u); Py_XDECREF(u);
Py_DECREF(v); Py_DECREF(v);
return NULL; return NULL;
@ -217,14 +220,17 @@ decode_unicode_with_escapes(Parser *parser, const char *s, size_t len, Token *t)
static PyObject * static PyObject *
decode_bytes_with_escapes(Parser *p, const char *s, Py_ssize_t len, Token *t) decode_bytes_with_escapes(Parser *p, const char *s, Py_ssize_t len, Token *t)
{ {
const char *first_invalid_escape; int first_invalid_escape_char;
PyObject *result = _PyBytes_DecodeEscape(s, len, NULL, &first_invalid_escape); const char *first_invalid_escape_ptr;
PyObject *result = _PyBytes_DecodeEscape2(s, len, NULL,
&first_invalid_escape_char,
&first_invalid_escape_ptr);
if (result == NULL) { if (result == NULL) {
return NULL; return NULL;
} }
if (first_invalid_escape != NULL) { if (first_invalid_escape_ptr != NULL) {
if (warn_invalid_escape_sequence(p, s, first_invalid_escape, t) < 0) { if (warn_invalid_escape_sequence(p, s, first_invalid_escape_ptr, t) < 0) {
Py_DECREF(result); Py_DECREF(result);
return NULL; return NULL;
} }