gh-89263: Add typing.get_overloads (GH-31716)

Based on suggestions by Guido van Rossum, Spencer Brown, and Alex Waygood.

Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
Co-authored-by: Guido van Rossum <gvanrossum@gmail.com>
Co-authored-by: Ken Jin <kenjin4096@gmail.com>
This commit is contained in:
Jelle Zijlstra 2022-04-16 09:01:43 -07:00 committed by GitHub
parent 9300b6d729
commit 055760ed9e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 133 additions and 4 deletions

View file

@ -2407,6 +2407,35 @@ Functions and decorators
See :pep:`484` for details and comparison with other typing semantics. See :pep:`484` for details and comparison with other typing semantics.
.. versionchanged:: 3.11
Overloaded functions can now be introspected at runtime using
:func:`get_overloads`.
.. function:: get_overloads(func)
Return a sequence of :func:`@overload <overload>`-decorated definitions for
*func*. *func* is the function object for the implementation of the
overloaded function. For example, given the definition of ``process`` in
the documentation for :func:`@overload <overload>`,
``get_overloads(process)`` will return a sequence of three function objects
for the three defined overloads. If called on a function with no overloads,
``get_overloads`` returns an empty sequence.
``get_overloads`` can be used for introspecting an overloaded function at
runtime.
.. versionadded:: 3.11
.. function:: clear_overloads()
Clear all registered overloads in the internal registry. This can be used
to reclaim the memory used by the registry.
.. versionadded:: 3.11
.. decorator:: final .. decorator:: final
A decorator to indicate to type checkers that the decorated method A decorator to indicate to type checkers that the decorated method

View file

@ -1,5 +1,6 @@
import contextlib import contextlib
import collections import collections
from collections import defaultdict
from functools import lru_cache from functools import lru_cache
import inspect import inspect
import pickle import pickle
@ -7,9 +8,11 @@ import re
import sys import sys
import warnings import warnings
from unittest import TestCase, main, skipUnless, skip from unittest import TestCase, main, skipUnless, skip
from unittest.mock import patch
from copy import copy, deepcopy from copy import copy, deepcopy
from typing import Any, NoReturn, Never, assert_never from typing import Any, NoReturn, Never, assert_never
from typing import overload, get_overloads, clear_overloads
from typing import TypeVar, TypeVarTuple, Unpack, AnyStr from typing import TypeVar, TypeVarTuple, Unpack, AnyStr
from typing import T, KT, VT # Not in __all__. from typing import T, KT, VT # Not in __all__.
from typing import Union, Optional, Literal from typing import Union, Optional, Literal
@ -3890,11 +3893,22 @@ class ForwardRefTests(BaseTestCase):
self.assertEqual("x" | X, Union["x", X]) self.assertEqual("x" | X, Union["x", X])
@lru_cache()
def cached_func(x, y):
return 3 * x + y
class MethodHolder:
@classmethod
def clsmethod(cls): ...
@staticmethod
def stmethod(): ...
def method(self): ...
class OverloadTests(BaseTestCase): class OverloadTests(BaseTestCase):
def test_overload_fails(self): def test_overload_fails(self):
from typing import overload
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
@overload @overload
@ -3904,8 +3918,6 @@ class OverloadTests(BaseTestCase):
blah() blah()
def test_overload_succeeds(self): def test_overload_succeeds(self):
from typing import overload
@overload @overload
def blah(): def blah():
pass pass
@ -3915,6 +3927,58 @@ class OverloadTests(BaseTestCase):
blah() blah()
def set_up_overloads(self):
def blah():
pass
overload1 = blah
overload(blah)
def blah():
pass
overload2 = blah
overload(blah)
def blah():
pass
return blah, [overload1, overload2]
# Make sure we don't clear the global overload registry
@patch("typing._overload_registry",
defaultdict(lambda: defaultdict(dict)))
def test_overload_registry(self):
# The registry starts out empty
self.assertEqual(typing._overload_registry, {})
impl, overloads = self.set_up_overloads()
self.assertNotEqual(typing._overload_registry, {})
self.assertEqual(list(get_overloads(impl)), overloads)
def some_other_func(): pass
overload(some_other_func)
other_overload = some_other_func
def some_other_func(): pass
self.assertEqual(list(get_overloads(some_other_func)), [other_overload])
# Make sure that after we clear all overloads, the registry is
# completely empty.
clear_overloads()
self.assertEqual(typing._overload_registry, {})
self.assertEqual(get_overloads(impl), [])
# Querying a function with no overloads shouldn't change the registry.
def the_only_one(): pass
self.assertEqual(get_overloads(the_only_one), [])
self.assertEqual(typing._overload_registry, {})
def test_overload_registry_repeated(self):
for _ in range(2):
impl, overloads = self.set_up_overloads()
self.assertEqual(list(get_overloads(impl)), overloads)
# Definitions needed for features introduced in Python 3.6 # Definitions needed for features introduced in Python 3.6

View file

@ -21,6 +21,7 @@ At large scale, the structure of the module is following:
from abc import abstractmethod, ABCMeta from abc import abstractmethod, ABCMeta
import collections import collections
from collections import defaultdict
import collections.abc import collections.abc
import contextlib import contextlib
import functools import functools
@ -121,9 +122,11 @@ __all__ = [
'assert_type', 'assert_type',
'assert_never', 'assert_never',
'cast', 'cast',
'clear_overloads',
'final', 'final',
'get_args', 'get_args',
'get_origin', 'get_origin',
'get_overloads',
'get_type_hints', 'get_type_hints',
'is_typeddict', 'is_typeddict',
'LiteralString', 'LiteralString',
@ -2450,6 +2453,10 @@ def _overload_dummy(*args, **kwds):
"by an implementation that is not @overload-ed.") "by an implementation that is not @overload-ed.")
# {module: {qualname: {firstlineno: func}}}
_overload_registry = defaultdict(functools.partial(defaultdict, dict))
def overload(func): def overload(func):
"""Decorator for overloaded functions/methods. """Decorator for overloaded functions/methods.
@ -2475,10 +2482,37 @@ def overload(func):
def utf8(value: str) -> bytes: ... def utf8(value: str) -> bytes: ...
def utf8(value): def utf8(value):
# implementation goes here # implementation goes here
The overloads for a function can be retrieved at runtime using the
get_overloads() function.
""" """
# classmethod and staticmethod
f = getattr(func, "__func__", func)
try:
_overload_registry[f.__module__][f.__qualname__][f.__code__.co_firstlineno] = func
except AttributeError:
# Not a normal function; ignore.
pass
return _overload_dummy return _overload_dummy
def get_overloads(func):
"""Return all defined overloads for *func* as a sequence."""
# classmethod and staticmethod
f = getattr(func, "__func__", func)
if f.__module__ not in _overload_registry:
return []
mod_dict = _overload_registry[f.__module__]
if f.__qualname__ not in mod_dict:
return []
return list(mod_dict[f.__qualname__].values())
def clear_overloads():
"""Clear all overloads in the registry."""
_overload_registry.clear()
def final(f): def final(f):
"""A decorator to indicate final methods and final classes. """A decorator to indicate final methods and final classes.

View file

@ -0,0 +1,2 @@
Add :func:`typing.get_overloads` and :func:`typing.clear_overloads`.
Patch by Jelle Zijlstra.