gh-89263: Add typing.get_overloads (GH-31716)

Based on suggestions by Guido van Rossum, Spencer Brown, and Alex Waygood.

Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
Co-authored-by: Guido van Rossum <gvanrossum@gmail.com>
Co-authored-by: Ken Jin <kenjin4096@gmail.com>
This commit is contained in:
Jelle Zijlstra 2022-04-16 09:01:43 -07:00 committed by GitHub
parent 9300b6d729
commit 055760ed9e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 133 additions and 4 deletions

View file

@ -1,5 +1,6 @@
import contextlib
import collections
from collections import defaultdict
from functools import lru_cache
import inspect
import pickle
@ -7,9 +8,11 @@ import re
import sys
import warnings
from unittest import TestCase, main, skipUnless, skip
from unittest.mock import patch
from copy import copy, deepcopy
from typing import Any, NoReturn, Never, assert_never
from typing import overload, get_overloads, clear_overloads
from typing import TypeVar, TypeVarTuple, Unpack, AnyStr
from typing import T, KT, VT # Not in __all__.
from typing import Union, Optional, Literal
@ -3890,11 +3893,22 @@ class ForwardRefTests(BaseTestCase):
self.assertEqual("x" | X, Union["x", X])
@lru_cache()
def cached_func(x, y):
return 3 * x + y
class MethodHolder:
@classmethod
def clsmethod(cls): ...
@staticmethod
def stmethod(): ...
def method(self): ...
class OverloadTests(BaseTestCase):
def test_overload_fails(self):
from typing import overload
with self.assertRaises(RuntimeError):
@overload
@ -3904,8 +3918,6 @@ class OverloadTests(BaseTestCase):
blah()
def test_overload_succeeds(self):
from typing import overload
@overload
def blah():
pass
@ -3915,6 +3927,58 @@ class OverloadTests(BaseTestCase):
blah()
def set_up_overloads(self):
def blah():
pass
overload1 = blah
overload(blah)
def blah():
pass
overload2 = blah
overload(blah)
def blah():
pass
return blah, [overload1, overload2]
# Make sure we don't clear the global overload registry
@patch("typing._overload_registry",
defaultdict(lambda: defaultdict(dict)))
def test_overload_registry(self):
# The registry starts out empty
self.assertEqual(typing._overload_registry, {})
impl, overloads = self.set_up_overloads()
self.assertNotEqual(typing._overload_registry, {})
self.assertEqual(list(get_overloads(impl)), overloads)
def some_other_func(): pass
overload(some_other_func)
other_overload = some_other_func
def some_other_func(): pass
self.assertEqual(list(get_overloads(some_other_func)), [other_overload])
# Make sure that after we clear all overloads, the registry is
# completely empty.
clear_overloads()
self.assertEqual(typing._overload_registry, {})
self.assertEqual(get_overloads(impl), [])
# Querying a function with no overloads shouldn't change the registry.
def the_only_one(): pass
self.assertEqual(get_overloads(the_only_one), [])
self.assertEqual(typing._overload_registry, {})
def test_overload_registry_repeated(self):
for _ in range(2):
impl, overloads = self.set_up_overloads()
self.assertEqual(list(get_overloads(impl)), overloads)
# Definitions needed for features introduced in Python 3.6