gh-132805: annotationlib: Fix handling of non-constant values in FORWARDREF (#132812)

Co-authored-by: David C Ellis <ducksual@gmail.com>
This commit is contained in:
Jelle Zijlstra 2025-05-04 08:49:13 -07:00 committed by GitHub
parent 7cb86c5def
commit c8f233c53b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 251 additions and 44 deletions

View file

@ -38,6 +38,7 @@ _SLOTS = (
"__weakref__", "__weakref__",
"__arg__", "__arg__",
"__globals__", "__globals__",
"__extra_names__",
"__code__", "__code__",
"__ast_node__", "__ast_node__",
"__cell__", "__cell__",
@ -82,6 +83,7 @@ class ForwardRef:
# is created through __class__ assignment on a _Stringifier object. # is created through __class__ assignment on a _Stringifier object.
self.__globals__ = None self.__globals__ = None
self.__cell__ = None self.__cell__ = None
self.__extra_names__ = None
# These are initially None but serve as a cache and may be set to a non-None # These are initially None but serve as a cache and may be set to a non-None
# value later. # value later.
self.__code__ = None self.__code__ = None
@ -151,6 +153,8 @@ class ForwardRef:
if not self.__forward_is_class__ or param_name not in globals: if not self.__forward_is_class__ or param_name not in globals:
globals[param_name] = param globals[param_name] = param
locals.pop(param_name, None) locals.pop(param_name, None)
if self.__extra_names__:
locals = {**locals, **self.__extra_names__}
arg = self.__forward_arg__ arg = self.__forward_arg__
if arg.isidentifier() and not keyword.iskeyword(arg): if arg.isidentifier() and not keyword.iskeyword(arg):
@ -231,6 +235,10 @@ class ForwardRef:
and self.__forward_is_class__ == other.__forward_is_class__ and self.__forward_is_class__ == other.__forward_is_class__
and self.__cell__ == other.__cell__ and self.__cell__ == other.__cell__
and self.__owner__ == other.__owner__ and self.__owner__ == other.__owner__
and (
(tuple(sorted(self.__extra_names__.items())) if self.__extra_names__ else None) ==
(tuple(sorted(other.__extra_names__.items())) if other.__extra_names__ else None)
)
) )
def __hash__(self): def __hash__(self):
@ -241,6 +249,7 @@ class ForwardRef:
self.__forward_is_class__, self.__forward_is_class__,
self.__cell__, self.__cell__,
self.__owner__, self.__owner__,
tuple(sorted(self.__extra_names__.items())) if self.__extra_names__ else None,
)) ))
def __or__(self, other): def __or__(self, other):
@ -274,6 +283,7 @@ class _Stringifier:
cell=None, cell=None,
*, *,
stringifier_dict, stringifier_dict,
extra_names=None,
): ):
# Either an AST node or a simple str (for the common case where a ForwardRef # Either an AST node or a simple str (for the common case where a ForwardRef
# represent a single name). # represent a single name).
@ -285,6 +295,7 @@ class _Stringifier:
self.__code__ = None self.__code__ = None
self.__ast_node__ = node self.__ast_node__ = node
self.__globals__ = globals self.__globals__ = globals
self.__extra_names__ = extra_names
self.__cell__ = cell self.__cell__ = cell
self.__owner__ = owner self.__owner__ = owner
self.__stringifier_dict__ = stringifier_dict self.__stringifier_dict__ = stringifier_dict
@ -292,28 +303,63 @@ class _Stringifier:
def __convert_to_ast(self, other): def __convert_to_ast(self, other):
if isinstance(other, _Stringifier): if isinstance(other, _Stringifier):
if isinstance(other.__ast_node__, str): if isinstance(other.__ast_node__, str):
return ast.Name(id=other.__ast_node__) return ast.Name(id=other.__ast_node__), other.__extra_names__
return other.__ast_node__ return other.__ast_node__, other.__extra_names__
elif isinstance(other, slice): elif (
return ast.Slice( # In STRING format we don't bother with the create_unique_name() dance;
lower=( # it's better to emit the repr() of the object instead of an opaque name.
self.__convert_to_ast(other.start) self.__stringifier_dict__.format == Format.STRING
if other.start is not None or other is None
else None or type(other) in (str, int, float, bool, complex)
), ):
upper=( return ast.Constant(value=other), None
self.__convert_to_ast(other.stop) elif type(other) is dict:
if other.stop is not None extra_names = {}
else None keys = []
), values = []
step=( for key, value in other.items():
self.__convert_to_ast(other.step) new_key, new_extra_names = self.__convert_to_ast(key)
if other.step is not None if new_extra_names is not None:
else None extra_names.update(new_extra_names)
), keys.append(new_key)
) new_value, new_extra_names = self.__convert_to_ast(value)
if new_extra_names is not None:
extra_names.update(new_extra_names)
values.append(new_value)
return ast.Dict(keys, values), extra_names
elif type(other) in (list, tuple, set):
extra_names = {}
elts = []
for elt in other:
new_elt, new_extra_names = self.__convert_to_ast(elt)
if new_extra_names is not None:
extra_names.update(new_extra_names)
elts.append(new_elt)
ast_class = {list: ast.List, tuple: ast.Tuple, set: ast.Set}[type(other)]
return ast_class(elts), extra_names
else: else:
return ast.Constant(value=other) name = self.__stringifier_dict__.create_unique_name()
return ast.Name(id=name), {name: other}
def __convert_to_ast_getitem(self, other):
    """Convert a subscript operand into an AST node.

    A ``slice`` becomes an ``ast.Slice`` whose lower/upper/step are the
    converted bounds; bounds that are ``None`` stay ``None``. Any other
    operand is delegated unchanged to ``__convert_to_ast``.

    Returns a ``(node, extra_names)`` pair. For a slice, ``extra_names`` is
    a dict merging whatever names the individual bounds reported (it may be
    empty).
    """
    # Guard clause: only slices need special treatment here.
    if not isinstance(other, slice):
        return self.__convert_to_ast(other)

    collected_names = {}
    converted_bounds = []
    # Keep the start/stop/step order so conversions happen in the same
    # sequence as the bounds appear in the slice.
    for bound in (other.start, other.stop, other.step):
        if bound is None:
            converted_bounds.append(None)
            continue
        node, names = self.__convert_to_ast(bound)
        if names is not None:
            collected_names.update(names)
        converted_bounds.append(node)

    lower, upper, step = converted_bounds
    return ast.Slice(lower=lower, upper=upper, step=step), collected_names
def __get_ast(self): def __get_ast(self):
node = self.__ast_node__ node = self.__ast_node__
@ -321,13 +367,19 @@ class _Stringifier:
return ast.Name(id=node) return ast.Name(id=node)
return node return node
def __make_new(self, node): def __make_new(self, node, extra_names=None):
new_extra_names = {}
if self.__extra_names__ is not None:
new_extra_names.update(self.__extra_names__)
if extra_names is not None:
new_extra_names.update(extra_names)
stringifier = _Stringifier( stringifier = _Stringifier(
node, node,
self.__globals__, self.__globals__,
self.__owner__, self.__owner__,
self.__forward_is_class__, self.__forward_is_class__,
stringifier_dict=self.__stringifier_dict__, stringifier_dict=self.__stringifier_dict__,
extra_names=new_extra_names or None,
) )
self.__stringifier_dict__.stringifiers.append(stringifier) self.__stringifier_dict__.stringifiers.append(stringifier)
return stringifier return stringifier
@ -343,27 +395,37 @@ class _Stringifier:
if self.__ast_node__ == "__classdict__": if self.__ast_node__ == "__classdict__":
raise KeyError raise KeyError
if isinstance(other, tuple): if isinstance(other, tuple):
elts = [self.__convert_to_ast(elt) for elt in other] extra_names = {}
elts = []
for elt in other:
new_elt, new_extra_names = self.__convert_to_ast_getitem(elt)
if new_extra_names is not None:
extra_names.update(new_extra_names)
elts.append(new_elt)
other = ast.Tuple(elts) other = ast.Tuple(elts)
else: else:
other = self.__convert_to_ast(other) other, extra_names = self.__convert_to_ast_getitem(other)
assert isinstance(other, ast.AST), repr(other) assert isinstance(other, ast.AST), repr(other)
return self.__make_new(ast.Subscript(self.__get_ast(), other)) return self.__make_new(ast.Subscript(self.__get_ast(), other), extra_names)
def __getattr__(self, attr): def __getattr__(self, attr):
return self.__make_new(ast.Attribute(self.__get_ast(), attr)) return self.__make_new(ast.Attribute(self.__get_ast(), attr))
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.__make_new( extra_names = {}
ast.Call( ast_args = []
self.__get_ast(), for arg in args:
[self.__convert_to_ast(arg) for arg in args], new_arg, new_extra_names = self.__convert_to_ast(arg)
[ if new_extra_names is not None:
ast.keyword(key, self.__convert_to_ast(value)) extra_names.update(new_extra_names)
for key, value in kwargs.items() ast_args.append(new_arg)
], ast_kwargs = []
) for key, value in kwargs.items():
) new_value, new_extra_names = self.__convert_to_ast(value)
if new_extra_names is not None:
extra_names.update(new_extra_names)
ast_kwargs.append(ast.keyword(key, new_value))
return self.__make_new(ast.Call(self.__get_ast(), ast_args, ast_kwargs), extra_names)
def __iter__(self): def __iter__(self):
yield self.__make_new(ast.Starred(self.__get_ast())) yield self.__make_new(ast.Starred(self.__get_ast()))
@ -378,8 +440,9 @@ class _Stringifier:
def _make_binop(op: ast.AST): def _make_binop(op: ast.AST):
def binop(self, other): def binop(self, other):
rhs, extra_names = self.__convert_to_ast(other)
return self.__make_new( return self.__make_new(
ast.BinOp(self.__get_ast(), op, self.__convert_to_ast(other)) ast.BinOp(self.__get_ast(), op, rhs), extra_names
) )
return binop return binop
@ -402,8 +465,9 @@ class _Stringifier:
def _make_rbinop(op: ast.AST): def _make_rbinop(op: ast.AST):
def rbinop(self, other): def rbinop(self, other):
new_other, extra_names = self.__convert_to_ast(other)
return self.__make_new( return self.__make_new(
ast.BinOp(self.__convert_to_ast(other), op, self.__get_ast()) ast.BinOp(new_other, op, self.__get_ast()), extra_names
) )
return rbinop return rbinop
@ -426,12 +490,14 @@ class _Stringifier:
def _make_compare(op): def _make_compare(op):
def compare(self, other): def compare(self, other):
rhs, extra_names = self.__convert_to_ast(other)
return self.__make_new( return self.__make_new(
ast.Compare( ast.Compare(
left=self.__get_ast(), left=self.__get_ast(),
ops=[op], ops=[op],
comparators=[self.__convert_to_ast(other)], comparators=[rhs],
) ),
extra_names,
) )
return compare return compare
@ -459,13 +525,15 @@ class _Stringifier:
class _StringifierDict(dict): class _StringifierDict(dict):
def __init__(self, namespace, globals=None, owner=None, is_class=False): def __init__(self, namespace, *, globals=None, owner=None, is_class=False, format):
super().__init__(namespace) super().__init__(namespace)
self.namespace = namespace self.namespace = namespace
self.globals = globals self.globals = globals
self.owner = owner self.owner = owner
self.is_class = is_class self.is_class = is_class
self.stringifiers = [] self.stringifiers = []
self.next_id = 1
self.format = format
def __missing__(self, key): def __missing__(self, key):
fwdref = _Stringifier( fwdref = _Stringifier(
@ -478,6 +546,11 @@ class _StringifierDict(dict):
self.stringifiers.append(fwdref) self.stringifiers.append(fwdref)
return fwdref return fwdref
def create_unique_name(self):
    """Return the next sequentially numbered placeholder name.

    The name has the form ``__annotationlib_name_<n>__`` where ``<n>`` is
    the current value of ``self.next_id``; the counter is then advanced by
    one so the following call yields a different name.
    """
    current_id = self.next_id
    self.next_id = current_id + 1
    return f"__annotationlib_name_{current_id}__"
def call_evaluate_function(evaluate, format, *, owner=None): def call_evaluate_function(evaluate, format, *, owner=None):
"""Call an evaluate function. Evaluate functions are normally generated for """Call an evaluate function. Evaluate functions are normally generated for
@ -521,7 +594,7 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
# possibly constants if the annotate function uses them directly). We then # possibly constants if the annotate function uses them directly). We then
# convert each of those into a string to get an approximation of the # convert each of those into a string to get an approximation of the
# original source. # original source.
globals = _StringifierDict({}) globals = _StringifierDict({}, format=format)
if annotate.__closure__: if annotate.__closure__:
freevars = annotate.__code__.co_freevars freevars = annotate.__code__.co_freevars
new_closure = [] new_closure = []
@ -544,9 +617,9 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
) )
annos = func(Format.VALUE_WITH_FAKE_GLOBALS) annos = func(Format.VALUE_WITH_FAKE_GLOBALS)
if _is_evaluate: if _is_evaluate:
return annos if isinstance(annos, str) else repr(annos) return _stringify_single(annos)
return { return {
key: val if isinstance(val, str) else repr(val) key: _stringify_single(val)
for key, val in annos.items() for key, val in annos.items()
} }
elif format == Format.FORWARDREF: elif format == Format.FORWARDREF:
@ -569,7 +642,13 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
# that returns a bool and an defined set of attributes. # that returns a bool and an defined set of attributes.
namespace = {**annotate.__builtins__, **annotate.__globals__} namespace = {**annotate.__builtins__, **annotate.__globals__}
is_class = isinstance(owner, type) is_class = isinstance(owner, type)
globals = _StringifierDict(namespace, annotate.__globals__, owner, is_class) globals = _StringifierDict(
namespace,
globals=annotate.__globals__,
owner=owner,
is_class=is_class,
format=format,
)
if annotate.__closure__: if annotate.__closure__:
freevars = annotate.__code__.co_freevars freevars = annotate.__code__.co_freevars
new_closure = [] new_closure = []
@ -619,6 +698,16 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
raise ValueError(f"Invalid format: {format!r}") raise ValueError(f"Invalid format: {format!r}")
def _stringify_single(anno):
    """Return the string form of a single annotation value.

    * ``str`` values are returned unchanged.
    * The ``...`` singleton is rendered as ``"..."``.
    * Everything else is rendered with ``repr()``.
    """
    # Strings pass through untouched so that PEP 563 stringified annotations
    # are supported (they are already source text, not values to repr()).
    if isinstance(anno, str):
        return anno
    if anno is ...:
        return "..."
    return repr(anno)
def get_annotate_from_class_namespace(obj): def get_annotate_from_class_namespace(obj):
"""Retrieve the annotate function from a class namespace dictionary. """Retrieve the annotate function from a class namespace dictionary.

View file

@ -121,6 +121,28 @@ class TestForwardRefFormat(unittest.TestCase):
self.assertIsInstance(gamma_anno, ForwardRef) self.assertIsInstance(gamma_anno, ForwardRef)
self.assertEqual(gamma_anno, support.EqualToForwardRef("some < obj", owner=f)) self.assertEqual(gamma_anno, support.EqualToForwardRef("some < obj", owner=f))
def test_partially_nonexistent_union(self):
    """Check FORWARDREF results for unions that mix a real and an undefined name."""
    class UnionForwardrefs:
        pipe: str | undefined
        union: Union[str, undefined]

    annos = get_annotations(UnionForwardrefs, format=Format.FORWARDREF)

    # The '|' spelling is expected to come back as one ForwardRef that can be
    # evaluated once the missing name is supplied.
    pipe_anno = annos["pipe"]
    self.assertIsInstance(pipe_anno, ForwardRef)
    resolved = pipe_anno.evaluate(globals={"undefined": int})
    self.assertEqual(resolved, str | int)

    # The Union[...] spelling is expected to come back as a Union whose
    # second argument is a ForwardRef for the missing name.
    union_anno = annos["union"]
    self.assertIsInstance(union_anno, Union)
    first_arg, second_arg = typing.get_args(union_anno)
    self.assertIs(first_arg, str)
    expected_ref = support.EqualToForwardRef(
        "undefined", is_class=True, owner=UnionForwardrefs
    )
    self.assertEqual(second_arg, expected_ref)
class TestStringFormat(unittest.TestCase): class TestStringFormat(unittest.TestCase):
def test_closure(self): def test_closure(self):
@ -251,6 +273,89 @@ class TestStringFormat(unittest.TestCase):
}, },
) )
def test_getitem(self):
    """Subscripting an undefined name is handled in STRING and FORWARDREF formats."""
    def f(x: undef1[str, undef2]):
        pass

    as_strings = annotationlib.get_annotations(f, format=Format.STRING)
    self.assertEqual(as_strings, {"x": "undef1[str, undef2]"})

    as_refs = annotationlib.get_annotations(f, format=Format.FORWARDREF)
    ref = as_refs["x"]
    self.assertIsInstance(ref, ForwardRef)
    # Supplying both missing names must yield the real subscripted type.
    evaluated = ref.evaluate(globals={"undef1": dict, "undef2": float})
    self.assertEqual(evaluated, dict[str, float])
def test_slice(self):
    """Slice syntax inside a subscript is reproduced in STRING format."""
    def single_slice(x: a[b:c]):
        pass

    result = annotationlib.get_annotations(single_slice, format=Format.STRING)
    self.assertEqual(result, {"x": "a[b:c]"})

    def two_slices(x: a[b:c, d:e]):
        pass

    result = annotationlib.get_annotations(two_slices, format=Format.STRING)
    self.assertEqual(result, {"x": "a[b:c, d:e]"})

    # A slice object referenced by name is expected to be reported by that
    # name; the variable must stay called ``obj`` to match the expectation.
    obj = slice(1, 1, 1)

    def named_slice(x: obj):
        pass

    result = annotationlib.get_annotations(named_slice, format=Format.STRING)
    self.assertEqual(result, {"x": "obj"})
def test_literals(self):
    """Literal annotation values are stringified as expected in STRING format."""
    def f(
        a: 1,
        b: 1.0,
        c: "hello",
        d: b"hello",
        e: True,
        f: None,
        g: ...,
        h: 1j,
    ):
        pass

    # Note that the str annotation is expected back without quotes, while
    # the bytes annotation is expected as its repr.
    expected = {
        "a": "1",
        "b": "1.0",
        "c": "hello",
        "d": "b'hello'",
        "e": "True",
        "f": "None",
        "g": "...",
        "h": "1j",
    }
    actual = annotationlib.get_annotations(f, format=Format.STRING)
    self.assertEqual(actual, expected)
def test_displays(self):
    """List, set, dict and tuple displays inside a subscript round-trip in STRING format."""
    # Start with a single list display.
    def f(x: a[[int, str], float]):
        pass

    simple = annotationlib.get_annotations(f, format=Format.STRING)
    self.assertEqual(simple, {"x": "a[[int, str], float]"})

    # Then one parameter per display kind: list, set, dict, tuple.
    def g(
        w: a[[int, str], float],
        x: a[{int, str}, 3],
        y: a[{int: str}, 4],
        z: a[(int, str), 5],
    ):
        pass

    expected = {
        "w": "a[[int, str], float]",
        "x": "a[{int, str}, 3]",
        "y": "a[{int: str}, 4]",
        "z": "a[(int, str), 5]",
    }
    combined = annotationlib.get_annotations(g, format=Format.STRING)
    self.assertEqual(combined, expected)
def test_nested_expressions(self): def test_nested_expressions(self):
def f( def f(
nested: list[Annotated[set[int], "set of ints", 4j]], nested: list[Annotated[set[int], "set of ints", 4j]],
@ -296,6 +401,17 @@ class TestStringFormat(unittest.TestCase):
with self.assertRaisesRegex(TypeError, format_msg): with self.assertRaisesRegex(TypeError, format_msg):
get_annotations(f, format=Format.STRING) get_annotations(f, format=Format.STRING)
def test_shenanigans(self):
    """Annotations whose source cannot be reconstructed still produce a usable string."""
    # ``(1).__class__`` evaluates to a real object, so the original source
    # text is not recoverable; the expectation is the object's repr.
    def f(x: x | (1).__class__, y: (1).__class__):
        pass

    expected = {"x": "x | <class 'int'>", "y": "<class 'int'>"}
    self.assertEqual(get_annotations(f, format=Format.STRING), expected)
class TestGetAnnotations(unittest.TestCase): class TestGetAnnotations(unittest.TestCase):
def test_builtin_type(self): def test_builtin_type(self):

View file

@ -0,0 +1,2 @@
Fix incorrect handling of nested non-constant values in the FORWARDREF
format in :mod:`annotationlib`.