[3.12] gh-133767: Fix use-after-free in the unicode-escape decoder with an error handler (GH-129648) (GH-133944) (#134337)

If the error handler is used, a new bytes object is created to set as
the object attribute of UnicodeDecodeError, and that bytes object then
replaces the original data. A pointer to the decoded data will became invalid
after destroying that temporary bytes object. So we need other way to return
the first invalid escape from _PyUnicode_DecodeUnicodeEscapeInternal().

_PyBytes_DecodeEscape() does not have such issue, because it does not
use the error handlers registry, but it should be changed for compatibility
with _PyUnicode_DecodeUnicodeEscapeInternal().
(cherry picked from commit 9f69a58623)
(cherry picked from commit 6279eb8c07)
This commit is contained in:
Serhiy Storchaka 2025-05-26 06:33:22 +03:00 committed by GitHub
parent 310cd8943a
commit 4398b788ff
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 194 additions and 57 deletions

View file

@ -25,6 +25,10 @@ PyAPI_FUNC(PyObject*) _PyBytes_FromHex(
int use_bytearray);
/* Helper for PyBytes_DecodeEscape that detects invalid escape chars. */
PyAPI_FUNC(PyObject*) _PyBytes_DecodeEscape2(const char *, Py_ssize_t,
const char *,
int *, const char **);
// Export for binary compatibility.
PyAPI_FUNC(PyObject *) _PyBytes_DecodeEscape(const char *, Py_ssize_t,
const char *, const char **);

View file

@ -684,6 +684,19 @@ PyAPI_FUNC(PyObject*) _PyUnicode_DecodeUnicodeEscapeStateful(
);
/* Helper for PyUnicode_DecodeUnicodeEscape that detects invalid escape
chars. */
PyAPI_FUNC(PyObject*) _PyUnicode_DecodeUnicodeEscapeInternal2(
const char *string, /* Unicode-Escape encoded string */
Py_ssize_t length, /* size of string */
const char *errors, /* error handling */
Py_ssize_t *consumed, /* bytes consumed */
int *first_invalid_escape_char, /* on return, if not -1, contain the first
invalid escaped char (<= 0xff) or invalid
octal escape (> 0xff) in string. */
const char **first_invalid_escape_ptr); /* on return, if not NULL, may
point to the first invalid escaped
char in string.
May be NULL if errors is not NULL. */
// Export for binary compatibility.
PyAPI_FUNC(PyObject*) _PyUnicode_DecodeUnicodeEscapeInternal(
const char *string, /* Unicode-Escape encoded string */
Py_ssize_t length, /* size of string */

View file

@ -1,6 +1,7 @@
import codecs
import html.entities
import itertools
import re
import sys
import unicodedata
import unittest
@ -1124,7 +1125,7 @@ class CodecCallbackTest(unittest.TestCase):
text = 'abc<def>ghi'*n
text.translate(charmap)
def test_mutatingdecodehandler(self):
def test_mutating_decode_handler(self):
baddata = [
("ascii", b"\xff"),
("utf-7", b"++"),
@ -1159,6 +1160,42 @@ class CodecCallbackTest(unittest.TestCase):
for (encoding, data) in baddata:
self.assertEqual(data.decode(encoding, "test.mutating"), "\u4242")
def test_mutating_decode_handler_unicode_escape(self):
decode = codecs.unicode_escape_decode
def mutating(exc):
if isinstance(exc, UnicodeDecodeError):
r = data.get(exc.object[:exc.end])
if r is not None:
exc.object = r[0] + exc.object[exc.end:]
return ('\u0404', r[1])
raise AssertionError("don't know how to handle %r" % exc)
codecs.register_error('test.mutating2', mutating)
data = {
br'\x0': (b'\\', 0),
br'\x3': (b'xxx\\', 3),
br'\x5': (b'x\\', 1),
}
def check(input, expected, msg):
with self.assertWarns(DeprecationWarning) as cm:
self.assertEqual(decode(input, 'test.mutating2'), (expected, len(input)))
self.assertIn(msg, str(cm.warning))
check(br'\x0n\z', '\u0404\n\\z', r"invalid escape sequence '\z'")
check(br'\x0n\501', '\u0404\n\u0141', r"invalid octal escape sequence '\501'")
check(br'\x0z', '\u0404\\z', r"invalid escape sequence '\z'")
check(br'\x3n\zr', '\u0404\n\\zr', r"invalid escape sequence '\z'")
check(br'\x3zr', '\u0404\\zr', r"invalid escape sequence '\z'")
check(br'\x3z5', '\u0404\\z5', r"invalid escape sequence '\z'")
check(memoryview(br'\x3z5x')[:-1], '\u0404\\z5', r"invalid escape sequence '\z'")
check(memoryview(br'\x3z5xy')[:-2], '\u0404\\z5', r"invalid escape sequence '\z'")
check(br'\x5n\z', '\u0404\n\\z', r"invalid escape sequence '\z'")
check(br'\x5n\501', '\u0404\n\u0141', r"invalid octal escape sequence '\501'")
check(br'\x5z', '\u0404\\z', r"invalid escape sequence '\z'")
check(memoryview(br'\x5zy')[:-1], '\u0404\\z', r"invalid escape sequence '\z'")
# issue32583
def test_crashing_decode_handler(self):
# better generating one more character to fill the extra space slot

View file

@ -1196,23 +1196,39 @@ class EscapeDecodeTest(unittest.TestCase):
check(br"[\1010]", b"[A0]")
check(br"[\x41]", b"[A]")
check(br"[\x410]", b"[A0]")
def test_warnings(self):
decode = codecs.escape_decode
check = coding_checker(self, decode)
for i in range(97, 123):
b = bytes([i])
if b not in b'abfnrtvx':
with self.assertWarns(DeprecationWarning):
with self.assertWarnsRegex(DeprecationWarning,
r"invalid escape sequence '\\%c'" % i):
check(b"\\" + b, b"\\" + b)
with self.assertWarns(DeprecationWarning):
with self.assertWarnsRegex(DeprecationWarning,
r"invalid escape sequence '\\%c'" % (i-32)):
check(b"\\" + b.upper(), b"\\" + b.upper())
with self.assertWarns(DeprecationWarning):
with self.assertWarnsRegex(DeprecationWarning,
r"invalid escape sequence '\\8'"):
check(br"\8", b"\\8")
with self.assertWarns(DeprecationWarning):
check(br"\9", b"\\9")
with self.assertWarns(DeprecationWarning):
with self.assertWarnsRegex(DeprecationWarning,
r"invalid escape sequence '\\\xfa'") as cm:
check(b"\\\xfa", b"\\\xfa")
for i in range(0o400, 0o1000):
with self.assertWarns(DeprecationWarning):
with self.assertWarnsRegex(DeprecationWarning,
r"invalid octal escape sequence '\\%o'" % i):
check(rb'\%o' % i, bytes([i & 0o377]))
with self.assertWarnsRegex(DeprecationWarning,
r"invalid escape sequence '\\z'"):
self.assertEqual(decode(br'\x\z', 'ignore'), (b'\\z', 4))
with self.assertWarnsRegex(DeprecationWarning,
r"invalid octal escape sequence '\\501'"):
self.assertEqual(decode(br'\x\501', 'ignore'), (b'A', 6))
def test_errors(self):
decode = codecs.escape_decode
self.assertRaises(ValueError, decode, br"\x")
@ -2479,24 +2495,40 @@ class UnicodeEscapeTest(ReadTest, unittest.TestCase):
check(br"[\x410]", "[A0]")
check(br"\u20ac", "\u20ac")
check(br"\U0001d120", "\U0001d120")
def test_decode_warnings(self):
decode = codecs.unicode_escape_decode
check = coding_checker(self, decode)
for i in range(97, 123):
b = bytes([i])
if b not in b'abfnrtuvx':
with self.assertWarns(DeprecationWarning):
with self.assertWarnsRegex(DeprecationWarning,
r"invalid escape sequence '\\%c'" % i):
check(b"\\" + b, "\\" + chr(i))
if b.upper() not in b'UN':
with self.assertWarns(DeprecationWarning):
with self.assertWarnsRegex(DeprecationWarning,
r"invalid escape sequence '\\%c'" % (i-32)):
check(b"\\" + b.upper(), "\\" + chr(i-32))
with self.assertWarns(DeprecationWarning):
with self.assertWarnsRegex(DeprecationWarning,
r"invalid escape sequence '\\8'"):
check(br"\8", "\\8")
with self.assertWarns(DeprecationWarning):
check(br"\9", "\\9")
with self.assertWarns(DeprecationWarning):
with self.assertWarnsRegex(DeprecationWarning,
r"invalid escape sequence '\\\xfa'") as cm:
check(b"\\\xfa", "\\\xfa")
for i in range(0o400, 0o1000):
with self.assertWarns(DeprecationWarning):
with self.assertWarnsRegex(DeprecationWarning,
r"invalid octal escape sequence '\\%o'" % i):
check(rb'\%o' % i, chr(i))
with self.assertWarnsRegex(DeprecationWarning,
r"invalid escape sequence '\\z'"):
self.assertEqual(decode(br'\x\z', 'ignore'), ('\\z', 4))
with self.assertWarnsRegex(DeprecationWarning,
r"invalid octal escape sequence '\\501'"):
self.assertEqual(decode(br'\x\501', 'ignore'), ('\u0141', 6))
def test_decode_errors(self):
decode = codecs.unicode_escape_decode
for c, d in (b'x', 2), (b'u', 4), (b'U', 4):

View file

@ -0,0 +1,2 @@
Fix use-after-free in the "unicode-escape" decoder with a non-"strict" error
handler.

View file

@ -1048,10 +1048,11 @@ _PyBytes_FormatEx(const char *format, Py_ssize_t format_len,
}
/* Unescape a backslash-escaped string. */
PyObject *_PyBytes_DecodeEscape(const char *s,
PyObject *_PyBytes_DecodeEscape2(const char *s,
Py_ssize_t len,
const char *errors,
const char **first_invalid_escape)
int *first_invalid_escape_char,
const char **first_invalid_escape_ptr)
{
int c;
char *p;
@ -1065,7 +1066,8 @@ PyObject *_PyBytes_DecodeEscape(const char *s,
return NULL;
writer.overallocate = 1;
*first_invalid_escape = NULL;
*first_invalid_escape_char = -1;
*first_invalid_escape_ptr = NULL;
end = s + len;
while (s < end) {
@ -1103,9 +1105,10 @@ PyObject *_PyBytes_DecodeEscape(const char *s,
c = (c<<3) + *s++ - '0';
}
if (c > 0377) {
if (*first_invalid_escape == NULL) {
*first_invalid_escape = s-3; /* Back up 3 chars, since we've
already incremented s. */
if (*first_invalid_escape_char == -1) {
*first_invalid_escape_char = c;
/* Back up 3 chars, since we've already incremented s. */
*first_invalid_escape_ptr = s - 3;
}
}
*p++ = c;
@ -1146,9 +1149,10 @@ PyObject *_PyBytes_DecodeEscape(const char *s,
break;
default:
if (*first_invalid_escape == NULL) {
*first_invalid_escape = s-1; /* Back up one char, since we've
already incremented s. */
if (*first_invalid_escape_char == -1) {
*first_invalid_escape_char = (unsigned char)s[-1];
/* Back up one char, since we've already incremented s. */
*first_invalid_escape_ptr = s - 1;
}
*p++ = '\\';
s--;
@ -1162,23 +1166,37 @@ PyObject *_PyBytes_DecodeEscape(const char *s,
return NULL;
}
// Export for binary compatibility.
PyObject *_PyBytes_DecodeEscape(const char *s,
Py_ssize_t len,
const char *errors,
const char **first_invalid_escape)
{
int first_invalid_escape_char;
return _PyBytes_DecodeEscape2(
s, len, errors,
&first_invalid_escape_char,
first_invalid_escape);
}
PyObject *PyBytes_DecodeEscape(const char *s,
Py_ssize_t len,
const char *errors,
Py_ssize_t Py_UNUSED(unicode),
const char *Py_UNUSED(recode_encoding))
{
const char* first_invalid_escape;
PyObject *result = _PyBytes_DecodeEscape(s, len, errors,
&first_invalid_escape);
int first_invalid_escape_char;
const char *first_invalid_escape_ptr;
PyObject *result = _PyBytes_DecodeEscape2(s, len, errors,
&first_invalid_escape_char,
&first_invalid_escape_ptr);
if (result == NULL)
return NULL;
if (first_invalid_escape != NULL) {
unsigned char c = *first_invalid_escape;
if ('4' <= c && c <= '7') {
if (first_invalid_escape_char != -1) {
if (first_invalid_escape_char > 0xff) {
if (PyErr_WarnFormat(PyExc_DeprecationWarning, 1,
"invalid octal escape sequence '\\%.3s'",
first_invalid_escape) < 0)
"invalid octal escape sequence '\\%o'",
first_invalid_escape_char) < 0)
{
Py_DECREF(result);
return NULL;
@ -1187,7 +1205,7 @@ PyObject *PyBytes_DecodeEscape(const char *s,
else {
if (PyErr_WarnFormat(PyExc_DeprecationWarning, 1,
"invalid escape sequence '\\%c'",
c) < 0)
first_invalid_escape_char) < 0)
{
Py_DECREF(result);
return NULL;

View file

@ -6046,13 +6046,15 @@ PyUnicode_AsUTF16String(PyObject *unicode)
/* --- Unicode Escape Codec ----------------------------------------------- */
PyObject *
_PyUnicode_DecodeUnicodeEscapeInternal(const char *s,
_PyUnicode_DecodeUnicodeEscapeInternal2(const char *s,
Py_ssize_t size,
const char *errors,
Py_ssize_t *consumed,
const char **first_invalid_escape)
int *first_invalid_escape_char,
const char **first_invalid_escape_ptr)
{
const char *starts = s;
const char *initial_starts = starts;
_PyUnicodeWriter writer;
const char *end;
PyObject *errorHandler = NULL;
@ -6061,7 +6063,8 @@ _PyUnicode_DecodeUnicodeEscapeInternal(const char *s,
PyInterpreterState *interp = _PyInterpreterState_Get();
// so we can remember if we've seen an invalid escape char or not
*first_invalid_escape = NULL;
*first_invalid_escape_char = -1;
*first_invalid_escape_ptr = NULL;
if (size == 0) {
if (consumed) {
@ -6149,9 +6152,12 @@ _PyUnicode_DecodeUnicodeEscapeInternal(const char *s,
}
}
if (ch > 0377) {
if (*first_invalid_escape == NULL) {
*first_invalid_escape = s-3; /* Back up 3 chars, since we've
already incremented s. */
if (*first_invalid_escape_char == -1) {
*first_invalid_escape_char = ch;
if (starts == initial_starts) {
/* Back up 3 chars, since we've already incremented s. */
*first_invalid_escape_ptr = s - 3;
}
}
}
WRITE_CHAR(ch);
@ -6252,9 +6258,12 @@ _PyUnicode_DecodeUnicodeEscapeInternal(const char *s,
goto error;
default:
if (*first_invalid_escape == NULL) {
*first_invalid_escape = s-1; /* Back up one char, since we've
already incremented s. */
if (*first_invalid_escape_char == -1) {
*first_invalid_escape_char = c;
if (starts == initial_starts) {
/* Back up one char, since we've already incremented s. */
*first_invalid_escape_ptr = s - 1;
}
}
WRITE_ASCII_CHAR('\\');
WRITE_CHAR(c);
@ -6293,24 +6302,40 @@ _PyUnicode_DecodeUnicodeEscapeInternal(const char *s,
return NULL;
}
// Export for binary compatibility.
PyObject *
_PyUnicode_DecodeUnicodeEscapeInternal(const char *s,
Py_ssize_t size,
const char *errors,
Py_ssize_t *consumed,
const char **first_invalid_escape)
{
int first_invalid_escape_char;
return _PyUnicode_DecodeUnicodeEscapeInternal2(
s, size, errors, consumed,
&first_invalid_escape_char,
first_invalid_escape);
}
PyObject *
_PyUnicode_DecodeUnicodeEscapeStateful(const char *s,
Py_ssize_t size,
const char *errors,
Py_ssize_t *consumed)
{
const char *first_invalid_escape;
PyObject *result = _PyUnicode_DecodeUnicodeEscapeInternal(s, size, errors,
int first_invalid_escape_char;
const char *first_invalid_escape_ptr;
PyObject *result = _PyUnicode_DecodeUnicodeEscapeInternal2(s, size, errors,
consumed,
&first_invalid_escape);
&first_invalid_escape_char,
&first_invalid_escape_ptr);
if (result == NULL)
return NULL;
if (first_invalid_escape != NULL) {
unsigned char c = *first_invalid_escape;
if ('4' <= c && c <= '7') {
if (first_invalid_escape_char != -1) {
if (first_invalid_escape_char > 0xff) {
if (PyErr_WarnFormat(PyExc_DeprecationWarning, 1,
"invalid octal escape sequence '\\%.3s'",
first_invalid_escape) < 0)
"invalid octal escape sequence '\\%o'",
first_invalid_escape_char) < 0)
{
Py_DECREF(result);
return NULL;
@ -6319,7 +6344,7 @@ _PyUnicode_DecodeUnicodeEscapeStateful(const char *s,
else {
if (PyErr_WarnFormat(PyExc_DeprecationWarning, 1,
"invalid escape sequence '\\%c'",
c) < 0)
first_invalid_escape_char) < 0)
{
Py_DECREF(result);
return NULL;

View file

@ -181,15 +181,18 @@ decode_unicode_with_escapes(Parser *parser, const char *s, size_t len, Token *t)
len = p - buf;
s = buf;
const char *first_invalid_escape;
v = _PyUnicode_DecodeUnicodeEscapeInternal(s, len, NULL, NULL, &first_invalid_escape);
int first_invalid_escape_char;
const char *first_invalid_escape_ptr;
v = _PyUnicode_DecodeUnicodeEscapeInternal2(s, (Py_ssize_t)len, NULL, NULL,
&first_invalid_escape_char,
&first_invalid_escape_ptr);
// HACK: later we can simply pass the line no, since we don't preserve the tokens
// when we are decoding the string but we preserve the line numbers.
if (v != NULL && first_invalid_escape != NULL && t != NULL) {
if (warn_invalid_escape_sequence(parser, s, first_invalid_escape, t) < 0) {
/* We have not decref u before because first_invalid_escape points
inside u. */
if (v != NULL && first_invalid_escape_ptr != NULL && t != NULL) {
if (warn_invalid_escape_sequence(parser, s, first_invalid_escape_ptr, t) < 0) {
/* We have not decref u before because first_invalid_escape_ptr
points inside u. */
Py_XDECREF(u);
Py_DECREF(v);
return NULL;
@ -202,14 +205,17 @@ decode_unicode_with_escapes(Parser *parser, const char *s, size_t len, Token *t)
static PyObject *
decode_bytes_with_escapes(Parser *p, const char *s, Py_ssize_t len, Token *t)
{
const char *first_invalid_escape;
PyObject *result = _PyBytes_DecodeEscape(s, len, NULL, &first_invalid_escape);
int first_invalid_escape_char;
const char *first_invalid_escape_ptr;
PyObject *result = _PyBytes_DecodeEscape2(s, len, NULL,
&first_invalid_escape_char,
&first_invalid_escape_ptr);
if (result == NULL) {
return NULL;
}
if (first_invalid_escape != NULL) {
if (warn_invalid_escape_sequence(p, s, first_invalid_escape, t) < 0) {
if (first_invalid_escape_ptr != NULL) {
if (warn_invalid_escape_sequence(p, s, first_invalid_escape_ptr, t) < 0) {
Py_DECREF(result);
return NULL;
}