From 8a9c6c4d16a746eea1e000d6701d1c274c1f331b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?B=C3=A9n=C3=A9dikt=20Tran?= <10796600+picnixz@users.noreply.github.com> Date: Sat, 19 Apr 2025 10:44:01 +0200 Subject: [PATCH] gh-128398: improve error messages when incorrectly using `with` and `async with` (#132218) Improve the error message with a suggestion when an object supporting the synchronous (resp. asynchronous) context manager protocol is entered using `async with` (resp. `with`) instead of `with` (resp. `async with`). --- Doc/whatsnew/3.14.rst | 6 + Include/internal/pycore_ceval.h | 11 ++ Lib/test/test_with.py | 105 +++++++++++++----- Lib/unittest/async_case.py | 14 ++- Lib/unittest/case.py | 13 ++- ...-04-07-13-46-57.gh-issue-128398.gJ2zIF.rst | 4 + Python/bytecodes.c | 9 +- Python/ceval.c | 74 ++++++++++-- Python/executor_cases.c.h | 11 +- Python/generated_cases.c.h | 11 +- 10 files changed, 211 insertions(+), 47 deletions(-) create mode 100644 Misc/NEWS.d/next/Core_and_Builtins/2025-04-07-13-46-57.gh-issue-128398.gJ2zIF.rst diff --git a/Doc/whatsnew/3.14.rst b/Doc/whatsnew/3.14.rst index aaa4702d53d..56858aee449 100644 --- a/Doc/whatsnew/3.14.rst +++ b/Doc/whatsnew/3.14.rst @@ -479,6 +479,12 @@ Other language changes :func:`textwrap.dedent`. (Contributed by Jon Crall and Steven Sun in :gh:`103998`.) +* Improve error message when an object supporting the synchronous (resp. + asynchronous) context manager protocol is entered using :keyword:`async + with` (resp. :keyword:`with`) instead of :keyword:`with` (resp. + :keyword:`async with`). + (Contributed by Bénédikt Tran in :gh:`128398`.) + .. _whatsnew314-pep765: diff --git a/Include/internal/pycore_ceval.h b/Include/internal/pycore_ceval.h index 18c8bc0624f..96ba54b274c 100644 --- a/Include/internal/pycore_ceval.h +++ b/Include/internal/pycore_ceval.h @@ -279,6 +279,7 @@ PyAPI_DATA(const conversion_func) _PyEval_ConversionFuncs[]; typedef struct _special_method { PyObject *name; const char *error; + const char *error_suggestion; // improved optional suggestion } _Py_SpecialMethod; PyAPI_DATA(const _Py_SpecialMethod) _Py_SpecialMethods[]; @@ -309,6 +310,16 @@ PyAPI_FUNC(PyObject *) _PyEval_LoadName(PyThreadState *tstate, _PyInterpreterFra PyAPI_FUNC(int) _Py_Check_ArgsIterable(PyThreadState *tstate, PyObject *func, PyObject *args); +/* + * Indicate whether a special method of given 'oparg' can use the (improved) + * alternative error message instead. Only methods loaded by LOAD_SPECIAL + * support alternative error messages. + * + * Symbol is exported for the JIT (see discussion on GH-132218). + */ +PyAPI_FUNC(int) +_PyEval_SpecialMethodCanSuggest(PyObject *self, int oparg); + /* Bits that can be set in PyThreadState.eval_breaker */ #define _PY_GIL_DROP_REQUEST_BIT (1U << 0) #define _PY_SIGNALS_PENDING_BIT (1U << 1) diff --git a/Lib/test/test_with.py b/Lib/test/test_with.py index 1d2ce9eccc4..fd7abd1782e 100644 --- a/Lib/test/test_with.py +++ b/Lib/test/test_with.py @@ -1,9 +1,10 @@ -"""Unit tests for the with statement specified in PEP 343.""" +"""Unit tests for the 'with/async with' statements specified in PEP 343/492.""" __author__ = "Mike Bland" __email__ = "mbland at acm dot org" +import re import sys import traceback import unittest @@ -11,6 +12,16 @@ from collections import deque from contextlib import _GeneratorContextManager, contextmanager, nullcontext +def do_with(obj): + with obj: + pass + + +async def do_async_with(obj): + async with obj: + pass + + class MockContextManager(_GeneratorContextManager): def __init__(self, *args): super().__init__(*args) @@ -110,34 +121,77 @@ class FailureTestCase(unittest.TestCase): with foo: pass self.assertRaises(NameError, fooNotDeclared) - def testEnterAttributeError1(self): - class LacksEnter(object): - def __exit__(self, type, value, traceback): - pass + def testEnterAttributeError(self): + class LacksEnter: + def __exit__(self, type, value, traceback): ... - def fooLacksEnter(): - foo = LacksEnter() - with foo: pass - self.assertRaisesRegex(TypeError, 'the context manager', fooLacksEnter) - - def testEnterAttributeError2(self): - class LacksEnterAndExit(object): - pass - - def fooLacksEnterAndExit(): - foo = LacksEnterAndExit() - with foo: pass - self.assertRaisesRegex(TypeError, 'the context manager', fooLacksEnterAndExit) + with self.assertRaisesRegex(TypeError, re.escape(( + "object does not support the context manager protocol " + "(missed __enter__ method)" + ))): + do_with(LacksEnter()) def testExitAttributeError(self): - class LacksExit(object): - def __enter__(self): - pass + class LacksExit: + def __enter__(self): ... - def fooLacksExit(): - foo = LacksExit() - with foo: pass - self.assertRaisesRegex(TypeError, 'the context manager.*__exit__', fooLacksExit) + msg = re.escape(( + "object does not support the context manager protocol " + "(missed __exit__ method)" + )) + # a missing __exit__ is reported missing before a missing __enter__ + with self.assertRaisesRegex(TypeError, msg): + do_with(object()) + with self.assertRaisesRegex(TypeError, msg): + do_with(LacksExit()) + + def testWithForAsyncManager(self): + class AsyncManager: + async def __aenter__(self): ... + async def __aexit__(self, type, value, traceback): ... + + with self.assertRaisesRegex(TypeError, re.escape(( + "object does not support the context manager protocol " + "(missed __exit__ method) but it supports the asynchronous " + "context manager protocol. Did you mean to use 'async with'?" + ))): + do_with(AsyncManager()) + + def testAsyncEnterAttributeError(self): + class LacksAsyncEnter: + async def __aexit__(self, type, value, traceback): ... + + with self.assertRaisesRegex(TypeError, re.escape(( + "object does not support the asynchronous context manager protocol " + "(missed __aenter__ method)" + ))): + do_async_with(LacksAsyncEnter()).send(None) + + def testAsyncExitAttributeError(self): + class LacksAsyncExit: + async def __aenter__(self): ... + + msg = re.escape(( + "object does not support the asynchronous context manager protocol " + "(missed __aexit__ method)" + )) + # a missing __aexit__ is reported missing before a missing __aenter__ + with self.assertRaisesRegex(TypeError, msg): + do_async_with(object()).send(None) + with self.assertRaisesRegex(TypeError, msg): + do_async_with(LacksAsyncExit()).send(None) + + def testAsyncWithForSyncManager(self): + class SyncManager: + def __enter__(self): ... + def __exit__(self, type, value, traceback): ... + + with self.assertRaisesRegex(TypeError, re.escape(( + "object does not support the asynchronous context manager protocol " + "(missed __aexit__ method) but it supports the context manager " + "protocol. Did you mean to use 'with'?" + ))): + do_async_with(SyncManager()).send(None) def assertRaisesSyntaxError(self, codestr): def shouldRaiseSyntaxError(s): @@ -190,6 +244,7 @@ class FailureTestCase(unittest.TestCase): pass self.assertRaises(RuntimeError, shouldThrow) + class ContextmanagerAssertionMixin(object): def setUp(self): diff --git a/Lib/unittest/async_case.py b/Lib/unittest/async_case.py index 6000af1cef0..a1c0d6c368c 100644 --- a/Lib/unittest/async_case.py +++ b/Lib/unittest/async_case.py @@ -75,9 +75,17 @@ class IsolatedAsyncioTestCase(TestCase): enter = cls.__aenter__ exit = cls.__aexit__ except AttributeError: - raise TypeError(f"'{cls.__module__}.{cls.__qualname__}' object does " - f"not support the asynchronous context manager protocol" - ) from None + msg = (f"'{cls.__module__}.{cls.__qualname__}' object does " + "not support the asynchronous context manager protocol") + try: + cls.__enter__ + cls.__exit__ + except AttributeError: + pass + else: + msg += (" but it supports the context manager protocol. " + "Did you mean to use enterContext()?") + raise TypeError(msg) from None result = await enter(cm) self.addAsyncCleanup(exit, cm, None, None, None) return result diff --git a/Lib/unittest/case.py b/Lib/unittest/case.py index 10c3b7e1223..884fc1b21f6 100644 --- a/Lib/unittest/case.py +++ b/Lib/unittest/case.py @@ -111,8 +111,17 @@ def _enter_context(cm, addcleanup): enter = cls.__enter__ exit = cls.__exit__ except AttributeError: - raise TypeError(f"'{cls.__module__}.{cls.__qualname__}' object does " - f"not support the context manager protocol") from None + msg = (f"'{cls.__module__}.{cls.__qualname__}' object does " + "not support the context manager protocol") + try: + cls.__aenter__ + cls.__aexit__ + except AttributeError: + pass + else: + msg += (" but it supports the asynchronous context manager " + "protocol. Did you mean to use enterAsyncContext()?") + raise TypeError(msg) from None result = enter(cm) addcleanup(exit, cm, None, None, None) return result diff --git a/Misc/NEWS.d/next/Core_and_Builtins/2025-04-07-13-46-57.gh-issue-128398.gJ2zIF.rst b/Misc/NEWS.d/next/Core_and_Builtins/2025-04-07-13-46-57.gh-issue-128398.gJ2zIF.rst new file mode 100644 index 00000000000..792332db6ef --- /dev/null +++ b/Misc/NEWS.d/next/Core_and_Builtins/2025-04-07-13-46-57.gh-issue-128398.gJ2zIF.rst @@ -0,0 +1,4 @@ +Improve error message when an object supporting the synchronous (resp. +asynchronous) context manager protocol is entered using :keyword:`async +with` (resp. :keyword:`with`) instead of :keyword:`with` (resp. +:keyword:`async with`). Patch by Bénédikt Tran. diff --git a/Python/bytecodes.c b/Python/bytecodes.c index 2796c3f2e85..07df22c761f 100644 --- a/Python/bytecodes.c +++ b/Python/bytecodes.c @@ -3425,9 +3425,12 @@ dummy_func( PyObject *attr_o = _PyObject_LookupSpecialMethod(owner_o, name, &self_or_null_o); if (attr_o == NULL) { if (!_PyErr_Occurred(tstate)) { - _PyErr_Format(tstate, PyExc_TypeError, - _Py_SpecialMethods[oparg].error, - Py_TYPE(owner_o)->tp_name); + const char *errfmt = _PyEval_SpecialMethodCanSuggest(owner_o, oparg) + ? _Py_SpecialMethods[oparg].error_suggestion + : _Py_SpecialMethods[oparg].error; + assert(!_PyErr_Occurred(tstate)); + assert(errfmt != NULL); + _PyErr_Format(tstate, PyExc_TypeError, errfmt, owner_o); } ERROR_IF(true, error); } diff --git a/Python/ceval.c b/Python/ceval.c index e534c7e2b88..17e28439872 100644 --- a/Python/ceval.c +++ b/Python/ceval.c @@ -545,23 +545,51 @@ const conversion_func _PyEval_ConversionFuncs[4] = { const _Py_SpecialMethod _Py_SpecialMethods[] = { [SPECIAL___ENTER__] = { .name = &_Py_ID(__enter__), - .error = "'%.200s' object does not support the " - "context manager protocol (missed __enter__ method)", + .error = ( + "'%T' object does not support the context manager protocol " + "(missed __enter__ method)" + ), + .error_suggestion = ( + "'%T' object does not support the context manager protocol " + "(missed __enter__ method) but it supports the asynchronous " + "context manager protocol. Did you mean to use 'async with'?" + ) }, [SPECIAL___EXIT__] = { .name = &_Py_ID(__exit__), - .error = "'%.200s' object does not support the " - "context manager protocol (missed __exit__ method)", + .error = ( + "'%T' object does not support the context manager protocol " + "(missed __exit__ method)" + ), + .error_suggestion = ( + "'%T' object does not support the context manager protocol " + "(missed __exit__ method) but it supports the asynchronous " + "context manager protocol. Did you mean to use 'async with'?" + ) }, [SPECIAL___AENTER__] = { .name = &_Py_ID(__aenter__), - .error = "'%.200s' object does not support the asynchronous " - "context manager protocol (missed __aenter__ method)", + .error = ( + "'%T' object does not support the asynchronous " + "context manager protocol (missed __aenter__ method)" + ), + .error_suggestion = ( + "'%T' object does not support the asynchronous context manager " + "protocol (missed __aenter__ method) but it supports the context " + "manager protocol. Did you mean to use 'with'?" + ) }, [SPECIAL___AEXIT__] = { .name = &_Py_ID(__aexit__), - .error = "'%.200s' object does not support the asynchronous " - "context manager protocol (missed __aexit__ method)", + .error = ( + "'%T' object does not support the asynchronous " + "context manager protocol (missed __aexit__ method)" + ), + .error_suggestion = ( + "'%T' object does not support the asynchronous context manager " + "protocol (missed __aexit__ method) but it supports the context " + "manager protocol. Did you mean to use 'with'?" + ) } }; @@ -3380,3 +3408,33 @@ _PyEval_LoadName(PyThreadState *tstate, _PyInterpreterFrame *frame, PyObject *na } return value; } + +/* Check if a 'cls' provides the given special method. */ +static inline int +type_has_special_method(PyTypeObject *cls, PyObject *name) +{ + // _PyType_Lookup() does not set an exception and returns a borrowed ref + assert(!PyErr_Occurred()); + PyObject *r = _PyType_Lookup(cls, name); + return r != NULL && Py_TYPE(r)->tp_descr_get != NULL; +} + +int +_PyEval_SpecialMethodCanSuggest(PyObject *self, int oparg) +{ + PyTypeObject *type = Py_TYPE(self); + switch (oparg) { + case SPECIAL___ENTER__: + case SPECIAL___EXIT__: { + return type_has_special_method(type, &_Py_ID(__aenter__)) + && type_has_special_method(type, &_Py_ID(__aexit__)); + } + case SPECIAL___AENTER__: + case SPECIAL___AEXIT__: { + return type_has_special_method(type, &_Py_ID(__enter__)) + && type_has_special_method(type, &_Py_ID(__exit__)); + } + default: + Py_FatalError("unsupported special method"); + } +} diff --git a/Python/executor_cases.c.h b/Python/executor_cases.c.h index 122285ba12e..cd265c383bd 100644 --- a/Python/executor_cases.c.h +++ b/Python/executor_cases.c.h @@ -4425,9 +4425,14 @@ if (attr_o == NULL) { if (!_PyErr_Occurred(tstate)) { _PyFrame_SetStackPointer(frame, stack_pointer); - _PyErr_Format(tstate, PyExc_TypeError, - _Py_SpecialMethods[oparg].error, - Py_TYPE(owner_o)->tp_name); + const char *errfmt = _PyEval_SpecialMethodCanSuggest(owner_o, oparg) + ? _Py_SpecialMethods[oparg].error_suggestion + : _Py_SpecialMethods[oparg].error; + stack_pointer = _PyFrame_GetStackPointer(frame); + assert(!_PyErr_Occurred(tstate)); + assert(errfmt != NULL); + _PyFrame_SetStackPointer(frame, stack_pointer); + _PyErr_Format(tstate, PyExc_TypeError, errfmt, owner_o); stack_pointer = _PyFrame_GetStackPointer(frame); } JUMP_TO_ERROR(); diff --git a/Python/generated_cases.c.h b/Python/generated_cases.c.h index cc85405f80b..911f5ae3e7c 100644 --- a/Python/generated_cases.c.h +++ b/Python/generated_cases.c.h @@ -9358,9 +9358,14 @@ if (attr_o == NULL) { if (!_PyErr_Occurred(tstate)) { _PyFrame_SetStackPointer(frame, stack_pointer); - _PyErr_Format(tstate, PyExc_TypeError, - _Py_SpecialMethods[oparg].error, - Py_TYPE(owner_o)->tp_name); + const char *errfmt = _PyEval_SpecialMethodCanSuggest(owner_o, oparg) + ? _Py_SpecialMethods[oparg].error_suggestion + : _Py_SpecialMethods[oparg].error; + stack_pointer = _PyFrame_GetStackPointer(frame); + assert(!_PyErr_Occurred(tstate)); + assert(errfmt != NULL); + _PyFrame_SetStackPointer(frame, stack_pointer); + _PyErr_Format(tstate, PyExc_TypeError, errfmt, owner_o); stack_pointer = _PyFrame_GetStackPointer(frame); } JUMP_TO_LABEL(error);