gh-119180: annotationlib: Fix __all__, formatting (#122365)

This commit is contained in:
Jelle Zijlstra 2024-08-11 16:44:51 -07:00 committed by GitHub
parent 016f4b5975
commit 4534068f22
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 68 additions and 33 deletions

View file

@ -6,7 +6,14 @@ import functools
import sys import sys
import types import types
__all__ = ["Format", "ForwardRef", "call_annotate_function", "get_annotations"] __all__ = [
"Format",
"ForwardRef",
"call_annotate_function",
"call_evaluate_function",
"get_annotate_function",
"get_annotations",
]
class Format(enum.IntEnum): class Format(enum.IntEnum):
@ -426,8 +433,7 @@ def call_evaluate_function(evaluate, format, *, owner=None):
return call_annotate_function(evaluate, format, owner=owner, _is_evaluate=True) return call_annotate_function(evaluate, format, owner=owner, _is_evaluate=True)
def call_annotate_function(annotate, format, *, owner=None, def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
_is_evaluate=False):
"""Call an __annotate__ function. __annotate__ functions are normally """Call an __annotate__ function. __annotate__ functions are normally
generated by the compiler to defer the evaluation of annotations. They generated by the compiler to defer the evaluation of annotations. They
can be called with any of the format arguments in the Format enum, but can be called with any of the format arguments in the Format enum, but
@ -473,8 +479,13 @@ def call_annotate_function(annotate, format, *, owner=None,
closure = tuple(new_closure) closure = tuple(new_closure)
else: else:
closure = None closure = None
func = types.FunctionType(annotate.__code__, globals, closure=closure, func = types.FunctionType(
argdefs=annotate.__defaults__, kwdefaults=annotate.__kwdefaults__) annotate.__code__,
globals,
closure=closure,
argdefs=annotate.__defaults__,
kwdefaults=annotate.__kwdefaults__,
)
annos = func(Format.VALUE) annos = func(Format.VALUE)
if _is_evaluate: if _is_evaluate:
return annos if isinstance(annos, str) else repr(annos) return annos if isinstance(annos, str) else repr(annos)
@ -528,8 +539,13 @@ def call_annotate_function(annotate, format, *, owner=None,
closure = tuple(new_closure) closure = tuple(new_closure)
else: else:
closure = None closure = None
func = types.FunctionType(annotate.__code__, globals, closure=closure, func = types.FunctionType(
argdefs=annotate.__defaults__, kwdefaults=annotate.__kwdefaults__) annotate.__code__,
globals,
closure=closure,
argdefs=annotate.__defaults__,
kwdefaults=annotate.__kwdefaults__,
)
result = func(Format.VALUE) result = func(Format.VALUE)
for obj in globals.stringifiers: for obj in globals.stringifiers:
obj.__class__ = ForwardRef obj.__class__ = ForwardRef

View file

@ -24,8 +24,6 @@ Here are some of the useful functions provided by this module:
stack(), trace() - get info about frames on the stack or in a traceback stack(), trace() - get info about frames on the stack or in a traceback
signature() - get a Signature object for the callable signature() - get a Signature object for the callable
get_annotations() - safely compute an object's annotations
""" """
# This module is in the public domain. No warranties. # This module is in the public domain. No warranties.
@ -142,7 +140,7 @@ __all__ = [
import abc import abc
from annotationlib import get_annotations from annotationlib import get_annotations # re-exported
import ast import ast
import dis import dis
import collections.abc import collections.abc

View file

@ -8,6 +8,7 @@ import unittest
from annotationlib import Format, ForwardRef, get_annotations, get_annotate_function from annotationlib import Format, ForwardRef, get_annotations, get_annotate_function
from typing import Unpack from typing import Unpack
from test import support
from test.test_inspect import inspect_stock_annotations from test.test_inspect import inspect_stock_annotations
from test.test_inspect import inspect_stringized_annotations from test.test_inspect import inspect_stringized_annotations
from test.test_inspect import inspect_stringized_annotations_2 from test.test_inspect import inspect_stringized_annotations_2
@ -327,7 +328,9 @@ class TestGetAnnotations(unittest.TestCase):
) )
self.assertEqual(annotationlib.get_annotations(NoDict), {"b": str}) self.assertEqual(annotationlib.get_annotations(NoDict), {"b": str})
self.assertEqual( self.assertEqual(
annotationlib.get_annotations(NoDict, format=annotationlib.Format.FORWARDREF), annotationlib.get_annotations(
NoDict, format=annotationlib.Format.FORWARDREF
),
{"b": str}, {"b": str},
) )
self.assertEqual( self.assertEqual(
@ -715,12 +718,13 @@ class TestGetAnnotations(unittest.TestCase):
) )
self.assertEqual(B_annotations, {"x": int, "y": str, "z": bytes}) self.assertEqual(B_annotations, {"x": int, "y": str, "z": bytes})
def test_pep695_generic_class_with_future_annotations_name_clash_with_global_vars(self): def test_pep695_generic_class_with_future_annotations_name_clash_with_global_vars(
self,
):
ann_module695 = inspect_stringized_annotations_pep695 ann_module695 = inspect_stringized_annotations_pep695
C_annotations = annotationlib.get_annotations(ann_module695.C, eval_str=True) C_annotations = annotationlib.get_annotations(ann_module695.C, eval_str=True)
self.assertEqual( self.assertEqual(
set(C_annotations.values()), set(C_annotations.values()), set(ann_module695.C.__type_params__)
set(ann_module695.C.__type_params__)
) )
def test_pep_695_generic_function_with_future_annotations(self): def test_pep_695_generic_function_with_future_annotations(self):
@ -737,17 +741,19 @@ class TestGetAnnotations(unittest.TestCase):
self.assertIs(generic_func_annotations["z"].__origin__, func_t_params[2]) self.assertIs(generic_func_annotations["z"].__origin__, func_t_params[2])
self.assertIs(generic_func_annotations["zz"].__origin__, func_t_params[2]) self.assertIs(generic_func_annotations["zz"].__origin__, func_t_params[2])
def test_pep_695_generic_function_with_future_annotations_name_clash_with_global_vars(self): def test_pep_695_generic_function_with_future_annotations_name_clash_with_global_vars(
self,
):
self.assertEqual( self.assertEqual(
set( set(
annotationlib.get_annotations( annotationlib.get_annotations(
inspect_stringized_annotations_pep695.generic_function_2, inspect_stringized_annotations_pep695.generic_function_2,
eval_str=True eval_str=True,
).values() ).values()
), ),
set( set(
inspect_stringized_annotations_pep695.generic_function_2.__type_params__ inspect_stringized_annotations_pep695.generic_function_2.__type_params__
) ),
) )
def test_pep_695_generic_method_with_future_annotations(self): def test_pep_695_generic_method_with_future_annotations(self):
@ -761,23 +767,27 @@ class TestGetAnnotations(unittest.TestCase):
} }
self.assertEqual( self.assertEqual(
generic_method_annotations, generic_method_annotations,
{"x": params["Foo"], "y": params["Bar"], "return": None} {"x": params["Foo"], "y": params["Bar"], "return": None},
) )
def test_pep_695_generic_method_with_future_annotations_name_clash_with_global_vars(self): def test_pep_695_generic_method_with_future_annotations_name_clash_with_global_vars(
self,
):
self.assertEqual( self.assertEqual(
set( set(
annotationlib.get_annotations( annotationlib.get_annotations(
inspect_stringized_annotations_pep695.D.generic_method_2, inspect_stringized_annotations_pep695.D.generic_method_2,
eval_str=True eval_str=True,
).values() ).values()
), ),
set( set(
inspect_stringized_annotations_pep695.D.generic_method_2.__type_params__ inspect_stringized_annotations_pep695.D.generic_method_2.__type_params__
) ),
) )
def test_pep_695_generic_method_with_future_annotations_name_clash_with_global_and_local_vars(self): def test_pep_695_generic_method_with_future_annotations_name_clash_with_global_and_local_vars(
self,
):
self.assertEqual( self.assertEqual(
annotationlib.get_annotations( annotationlib.get_annotations(
inspect_stringized_annotations_pep695.E, eval_str=True inspect_stringized_annotations_pep695.E, eval_str=True
@ -789,20 +799,20 @@ class TestGetAnnotations(unittest.TestCase):
results = inspect_stringized_annotations_pep695.nested() results = inspect_stringized_annotations_pep695.nested()
self.assertEqual( self.assertEqual(
set(results.F_annotations.values()), set(results.F_annotations.values()), set(results.F.__type_params__)
set(results.F.__type_params__)
) )
self.assertEqual( self.assertEqual(
set(results.F_meth_annotations.values()), set(results.F_meth_annotations.values()),
set(results.F.generic_method.__type_params__) set(results.F.generic_method.__type_params__),
) )
self.assertNotEqual( self.assertNotEqual(
set(results.F_meth_annotations.values()), set(results.F_meth_annotations.values()), set(results.F.__type_params__)
set(results.F.__type_params__)
) )
self.assertEqual( self.assertEqual(
set(results.F_meth_annotations.values()).intersection(results.F.__type_params__), set(results.F_meth_annotations.values()).intersection(
set() results.F.__type_params__
),
set(),
) )
self.assertEqual(results.G_annotations, {"x": str}) self.assertEqual(results.G_annotations, {"x": str})
@ -823,7 +833,9 @@ class TestCallEvaluateFunction(unittest.TestCase):
with self.assertRaises(NameError): with self.assertRaises(NameError):
annotationlib.call_evaluate_function(evaluate, annotationlib.Format.VALUE) annotationlib.call_evaluate_function(evaluate, annotationlib.Format.VALUE)
self.assertEqual( self.assertEqual(
annotationlib.call_evaluate_function(evaluate, annotationlib.Format.FORWARDREF), annotationlib.call_evaluate_function(
evaluate, annotationlib.Format.FORWARDREF
),
annotationlib.ForwardRef("undefined"), annotationlib.ForwardRef("undefined"),
) )
self.assertEqual( self.assertEqual(
@ -853,12 +865,14 @@ class MetaclassTests(unittest.TestCase):
self.assertEqual(get_annotate_function(Y)(Format.VALUE), {"b": float}) self.assertEqual(get_annotate_function(Y)(Format.VALUE), {"b": float})
def test_unannotated_meta(self): def test_unannotated_meta(self):
class Meta(type): pass class Meta(type):
pass
class X(metaclass=Meta): class X(metaclass=Meta):
a: str a: str
class Y(X): pass class Y(X):
pass
self.assertEqual(get_annotations(Meta), {}) self.assertEqual(get_annotations(Meta), {})
self.assertIs(get_annotate_function(Meta), None) self.assertIs(get_annotate_function(Meta), None)
@ -907,6 +921,13 @@ class MetaclassTests(unittest.TestCase):
self.assertEqual(get_annotations(c), c.expected_annotations) self.assertEqual(get_annotations(c), c.expected_annotations)
annotate_func = get_annotate_function(c) annotate_func = get_annotate_function(c)
if c.expected_annotations: if c.expected_annotations:
self.assertEqual(annotate_func(Format.VALUE), c.expected_annotations) self.assertEqual(
annotate_func(Format.VALUE), c.expected_annotations
)
else: else:
self.assertIs(annotate_func, None) self.assertIs(annotate_func, None)
class TestAnnotationLib(unittest.TestCase):
def test__all__(self):
support.check__all__(self, annotationlib)