gh-112006: Fix inspect.unwrap() for types where __wrapped__ is a data descriptor (GH-115540)

This also fixes inspect.Signature.from_callable() for builtins classmethod()
and staticmethod().
This commit is contained in:
Serhiy Storchaka 2024-02-26 20:07:41 +02:00 committed by GitHub
parent b05afdd5ec
commit 68c79d21fa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 32 additions and 13 deletions

View file

@ -762,18 +762,14 @@ def unwrap(func, *, stop=None):
:exc:`ValueError` is raised if a cycle is encountered. :exc:`ValueError` is raised if a cycle is encountered.
""" """
if stop is None:
def _is_wrapper(f):
return hasattr(f, '__wrapped__')
else:
def _is_wrapper(f):
return hasattr(f, '__wrapped__') and not stop(f)
f = func # remember the original func for error reporting f = func # remember the original func for error reporting
# Memoise by id to tolerate non-hashable objects, but store objects to # Memoise by id to tolerate non-hashable objects, but store objects to
# ensure they aren't destroyed, which would allow their IDs to be reused. # ensure they aren't destroyed, which would allow their IDs to be reused.
memo = {id(f): f} memo = {id(f): f}
recursion_limit = sys.getrecursionlimit() recursion_limit = sys.getrecursionlimit()
while _is_wrapper(func): while not isinstance(func, type) and hasattr(func, '__wrapped__'):
if stop is not None and stop(func):
break
func = func.__wrapped__ func = func.__wrapped__
id_func = id(func) id_func = id(func)
if (id_func in memo) or (len(memo) >= recursion_limit): if (id_func in memo) or (len(memo) >= recursion_limit):

View file

@ -3137,6 +3137,10 @@ class TestSignatureObject(unittest.TestCase):
int)) int))
def test_signature_on_classmethod(self): def test_signature_on_classmethod(self):
self.assertEqual(self.signature(classmethod),
((('function', ..., ..., "positional_only"),),
...))
class Test: class Test:
@classmethod @classmethod
def foo(cls, arg1, *, arg2=1): def foo(cls, arg1, *, arg2=1):
@ -3155,6 +3159,10 @@ class TestSignatureObject(unittest.TestCase):
...)) ...))
def test_signature_on_staticmethod(self): def test_signature_on_staticmethod(self):
self.assertEqual(self.signature(staticmethod),
((('function', ..., ..., "positional_only"),),
...))
class Test: class Test:
@staticmethod @staticmethod
def foo(cls, *, arg): def foo(cls, *, arg):
@ -3678,16 +3686,20 @@ class TestSignatureObject(unittest.TestCase):
((('a', ..., ..., "positional_or_keyword"),), ((('a', ..., ..., "positional_or_keyword"),),
...)) ...))
class Wrapped: def test_signature_on_wrapper(self):
pass class Wrapper:
Wrapped.__wrapped__ = lambda a: None def __call__(self, b):
self.assertEqual(self.signature(Wrapped), pass
wrapper = Wrapper()
wrapper.__wrapped__ = lambda a: None
self.assertEqual(self.signature(wrapper),
((('a', ..., ..., "positional_or_keyword"),), ((('a', ..., ..., "positional_or_keyword"),),
...)) ...))
# wrapper loop: # wrapper loop:
Wrapped.__wrapped__ = Wrapped wrapper = Wrapper()
wrapper.__wrapped__ = wrapper
with self.assertRaisesRegex(ValueError, 'wrapper loop'): with self.assertRaisesRegex(ValueError, 'wrapper loop'):
self.signature(Wrapped) self.signature(wrapper)
def test_signature_on_lambdas(self): def test_signature_on_lambdas(self):
self.assertEqual(self.signature((lambda a=10: a)), self.assertEqual(self.signature((lambda a=10: a)),
@ -4999,6 +5011,14 @@ class TestUnwrap(unittest.TestCase):
with self.assertRaisesRegex(ValueError, 'wrapper loop'): with self.assertRaisesRegex(ValueError, 'wrapper loop'):
inspect.unwrap(obj) inspect.unwrap(obj)
def test_wrapped_descriptor(self):
self.assertIs(inspect.unwrap(NTimesUnwrappable), NTimesUnwrappable)
self.assertIs(inspect.unwrap(staticmethod), staticmethod)
self.assertIs(inspect.unwrap(classmethod), classmethod)
self.assertIs(inspect.unwrap(staticmethod(classmethod)), classmethod)
self.assertIs(inspect.unwrap(classmethod(staticmethod)), staticmethod)
class TestMain(unittest.TestCase): class TestMain(unittest.TestCase):
def test_only_source(self): def test_only_source(self):
module = importlib.import_module('unittest') module = importlib.import_module('unittest')

View file

@ -0,0 +1,3 @@
Fix :func:`inspect.unwrap` for types with the ``__wrapper__`` data
descriptor. Fix :meth:`inspect.Signature.from_callable` for builtins
:func:`classmethod` and :func:`staticmethod`.