gh-130870: Preserve GenericAlias subclasses in typing.get_type_hints() (#131583)

This commit is contained in:
Victorien 2025-07-05 15:55:39 +02:00 committed by GitHub
parent f0c7344a8f
commit 5b56daa9d7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 44 additions and 12 deletions

View file

@ -1605,7 +1605,10 @@ class TypeVarTupleTests(BaseTestCase):
self.assertEqual(gth(func1), {'args': Unpack[Ts]})
def func2(*args: *tuple[int, str]): pass
self.assertEqual(gth(func2), {'args': Unpack[tuple[int, str]]})
hint = gth(func2)['args']
self.assertIsInstance(hint, types.GenericAlias)
self.assertEqual(hint.__args__[0], int)
self.assertIs(hint.__unpacked__, True)
class CustomVariadic(Generic[*Ts]): pass
@ -1620,7 +1623,10 @@ class TypeVarTupleTests(BaseTestCase):
{'args': Unpack[Ts]})
def func2(*args: '*tuple[int, str]'): pass
self.assertEqual(gth(func2), {'args': Unpack[tuple[int, str]]})
hint = gth(func2)['args']
self.assertIsInstance(hint, types.GenericAlias)
self.assertEqual(hint.__args__[0], int)
self.assertIs(hint.__unpacked__, True)
class CustomVariadic(Generic[*Ts]): pass
@ -7114,6 +7120,24 @@ class GetTypeHintsTests(BaseTestCase):
right_hints = get_type_hints(t.add_right, globals(), locals())
self.assertEqual(right_hints['node'], Node[T])
def test_get_type_hints_preserve_generic_alias_subclasses(self):
# https://github.com/python/cpython/issues/130870
# A real world example of this is `collections.abc.Callable`. When parameterized,
# the result is a subclass of `types.GenericAlias`.
class MyAlias(types.GenericAlias):
pass
class MyClass:
def __class_getitem__(cls, args):
return MyAlias(cls, args)
# Using a forward reference is important, otherwise it works as expected.
# `y` tests that the `GenericAlias` subclass is preserved when stripping `Annotated`.
def func(x: MyClass['int'], y: MyClass[Annotated[int, ...]]): ...
assert isinstance(get_type_hints(func)['x'], MyAlias)
assert isinstance(get_type_hints(func)['y'], MyAlias)
class GetUtilitiesTestCase(TestCase):
def test_get_origin(self):