gh-125618: Make FORWARDREF format succeed more often (#132818)

Fixes #125618.
This commit is contained in:
Jelle Zijlstra 2025-05-04 15:21:56 -07:00 committed by GitHub
parent 3109c47be8
commit af5799f305
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 237 additions and 63 deletions

View file

@ -132,7 +132,7 @@ Classes
Values are real annotation values (as per :attr:`Format.VALUE` format) Values are real annotation values (as per :attr:`Format.VALUE` format)
for defined values, and :class:`ForwardRef` proxies for undefined for defined values, and :class:`ForwardRef` proxies for undefined
values. Real objects may contain references to, :class:`ForwardRef` values. Real objects may contain references to :class:`ForwardRef`
proxy objects. proxy objects.
.. attribute:: STRING .. attribute:: STRING
@ -172,14 +172,21 @@ Classes
:class:`~ForwardRef`. The string may not be exactly equivalent :class:`~ForwardRef`. The string may not be exactly equivalent
to the original source. to the original source.
.. method:: evaluate(*, owner=None, globals=None, locals=None, type_params=None) .. method:: evaluate(*, owner=None, globals=None, locals=None, type_params=None, format=Format.VALUE)
Evaluate the forward reference, returning its value. Evaluate the forward reference, returning its value.
This may throw an exception, such as :exc:`NameError`, if the forward If the *format* argument is :attr:`~Format.VALUE` (the default),
this method may throw an exception, such as :exc:`NameError`, if the forward
reference refers to a name that cannot be resolved. The arguments to this reference refers to a name that cannot be resolved. The arguments to this
method can be used to provide bindings for names that would otherwise method can be used to provide bindings for names that would otherwise
be undefined. be undefined. If the *format* argument is :attr:`~Format.FORWARDREF`,
the method will never throw an exception, but may return a :class:`~ForwardRef`
instance. For example, if the forward reference object contains the code
``list[undefined]``, where ``undefined`` is a name that is not defined,
evaluating it with the :attr:`~Format.FORWARDREF` format will return
``list[ForwardRef('undefined')]``. If the *format* argument is
:attr:`~Format.STRING`, the method will return :attr:`~ForwardRef.__forward_arg__`.
The *owner* parameter provides the preferred mechanism for passing scope The *owner* parameter provides the preferred mechanism for passing scope
information to this method. The owner of a :class:`~ForwardRef` is the information to this method. The owner of a :class:`~ForwardRef` is the

View file

@ -92,11 +92,28 @@ class ForwardRef:
def __init_subclass__(cls, /, *args, **kwds): def __init_subclass__(cls, /, *args, **kwds):
raise TypeError("Cannot subclass ForwardRef") raise TypeError("Cannot subclass ForwardRef")
def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None): def evaluate(
self,
*,
globals=None,
locals=None,
type_params=None,
owner=None,
format=Format.VALUE,
):
"""Evaluate the forward reference and return the value. """Evaluate the forward reference and return the value.
If the forward reference cannot be evaluated, raise an exception. If the forward reference cannot be evaluated, raise an exception.
""" """
match format:
case Format.STRING:
return self.__forward_arg__
case Format.VALUE:
is_forwardref_format = False
case Format.FORWARDREF:
is_forwardref_format = True
case _:
raise NotImplementedError(format)
if self.__cell__ is not None: if self.__cell__ is not None:
try: try:
return self.__cell__.cell_contents return self.__cell__.cell_contents
@ -159,17 +176,36 @@ class ForwardRef:
arg = self.__forward_arg__ arg = self.__forward_arg__
if arg.isidentifier() and not keyword.iskeyword(arg): if arg.isidentifier() and not keyword.iskeyword(arg):
if arg in locals: if arg in locals:
value = locals[arg] return locals[arg]
elif arg in globals: elif arg in globals:
value = globals[arg] return globals[arg]
elif hasattr(builtins, arg): elif hasattr(builtins, arg):
return getattr(builtins, arg) return getattr(builtins, arg)
elif is_forwardref_format:
return self
else: else:
raise NameError(arg) raise NameError(arg)
else: else:
code = self.__forward_code__ code = self.__forward_code__
value = eval(code, globals=globals, locals=locals) try:
return value return eval(code, globals=globals, locals=locals)
except Exception:
if not is_forwardref_format:
raise
new_locals = _StringifierDict(
{**builtins.__dict__, **locals},
globals=globals,
owner=owner,
is_class=self.__forward_is_class__,
format=format,
)
try:
result = eval(code, globals=globals, locals=new_locals)
except Exception:
return self
else:
new_locals.transmogrify()
return result
def _evaluate(self, globalns, localns, type_params=_sentinel, *, recursive_guard): def _evaluate(self, globalns, localns, type_params=_sentinel, *, recursive_guard):
import typing import typing
@ -546,6 +582,14 @@ class _StringifierDict(dict):
self.stringifiers.append(fwdref) self.stringifiers.append(fwdref)
return fwdref return fwdref
def transmogrify(self):
for obj in self.stringifiers:
obj.__class__ = ForwardRef
obj.__stringifier_dict__ = None # not needed for ForwardRef
if isinstance(obj.__ast_node__, str):
obj.__arg__ = obj.__ast_node__
obj.__ast_node__ = None
def create_unique_name(self): def create_unique_name(self):
name = f"__annotationlib_name_{self.next_id}__" name = f"__annotationlib_name_{self.next_id}__"
self.next_id += 1 self.next_id += 1
@ -595,19 +639,10 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
# convert each of those into a string to get an approximation of the # convert each of those into a string to get an approximation of the
# original source. # original source.
globals = _StringifierDict({}, format=format) globals = _StringifierDict({}, format=format)
if annotate.__closure__: is_class = isinstance(owner, type)
freevars = annotate.__code__.co_freevars closure = _build_closure(
new_closure = [] annotate, owner, is_class, globals, allow_evaluation=False
for i, cell in enumerate(annotate.__closure__): )
if i < len(freevars):
name = freevars[i]
else:
name = "__cell__"
fwdref = _Stringifier(name, stringifier_dict=globals)
new_closure.append(types.CellType(fwdref))
closure = tuple(new_closure)
else:
closure = None
func = types.FunctionType( func = types.FunctionType(
annotate.__code__, annotate.__code__,
globals, globals,
@ -649,32 +684,36 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
is_class=is_class, is_class=is_class,
format=format, format=format,
) )
if annotate.__closure__: closure = _build_closure(
freevars = annotate.__code__.co_freevars annotate, owner, is_class, globals, allow_evaluation=True
new_closure = [] )
for i, cell in enumerate(annotate.__closure__): func = types.FunctionType(
try: annotate.__code__,
cell.cell_contents globals,
except ValueError: closure=closure,
if i < len(freevars): argdefs=annotate.__defaults__,
name = freevars[i] kwdefaults=annotate.__kwdefaults__,
else: )
name = "__cell__" try:
fwdref = _Stringifier( result = func(Format.VALUE_WITH_FAKE_GLOBALS)
name, except Exception:
cell=cell, pass
owner=owner,
globals=annotate.__globals__,
is_class=is_class,
stringifier_dict=globals,
)
globals.stringifiers.append(fwdref)
new_closure.append(types.CellType(fwdref))
else:
new_closure.append(cell)
closure = tuple(new_closure)
else: else:
closure = None globals.transmogrify()
return result
# Try again, but do not provide any globals. This allows us to return
# a value in certain cases where an exception gets raised during evaluation.
globals = _StringifierDict(
{},
globals=annotate.__globals__,
owner=owner,
is_class=is_class,
format=format,
)
closure = _build_closure(
annotate, owner, is_class, globals, allow_evaluation=False
)
func = types.FunctionType( func = types.FunctionType(
annotate.__code__, annotate.__code__,
globals, globals,
@ -683,13 +722,21 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
kwdefaults=annotate.__kwdefaults__, kwdefaults=annotate.__kwdefaults__,
) )
result = func(Format.VALUE_WITH_FAKE_GLOBALS) result = func(Format.VALUE_WITH_FAKE_GLOBALS)
for obj in globals.stringifiers: globals.transmogrify()
obj.__class__ = ForwardRef if _is_evaluate:
obj.__stringifier_dict__ = None # not needed for ForwardRef if isinstance(result, ForwardRef):
if isinstance(obj.__ast_node__, str): return result.evaluate(format=Format.FORWARDREF)
obj.__arg__ = obj.__ast_node__ else:
obj.__ast_node__ = None return result
return result else:
return {
key: (
val.evaluate(format=Format.FORWARDREF)
if isinstance(val, ForwardRef)
else val
)
for key, val in result.items()
}
elif format == Format.VALUE: elif format == Format.VALUE:
# Should be impossible because __annotate__ functions must not raise # Should be impossible because __annotate__ functions must not raise
# NotImplementedError for this format. # NotImplementedError for this format.
@ -698,6 +745,39 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
raise ValueError(f"Invalid format: {format!r}") raise ValueError(f"Invalid format: {format!r}")
def _build_closure(annotate, owner, is_class, stringifier_dict, *, allow_evaluation):
if not annotate.__closure__:
return None
freevars = annotate.__code__.co_freevars
new_closure = []
for i, cell in enumerate(annotate.__closure__):
if i < len(freevars):
name = freevars[i]
else:
name = "__cell__"
new_cell = None
if allow_evaluation:
try:
cell.cell_contents
except ValueError:
pass
else:
new_cell = cell
if new_cell is None:
fwdref = _Stringifier(
name,
cell=cell,
owner=owner,
globals=annotate.__globals__,
is_class=is_class,
stringifier_dict=stringifier_dict,
)
stringifier_dict.stringifiers.append(fwdref)
new_cell = types.CellType(fwdref)
new_closure.append(new_cell)
return tuple(new_closure)
def _stringify_single(anno): def _stringify_single(anno):
if anno is ...: if anno is ...:
return "..." return "..."
@ -809,7 +889,7 @@ def get_annotations(
# But if we didn't get it, we use __annotations__ instead. # But if we didn't get it, we use __annotations__ instead.
ann = _get_dunder_annotations(obj) ann = _get_dunder_annotations(obj)
if ann is not None: if ann is not None:
return annotations_to_string(ann) return annotations_to_string(ann)
case Format.VALUE_WITH_FAKE_GLOBALS: case Format.VALUE_WITH_FAKE_GLOBALS:
raise ValueError("The VALUE_WITH_FAKE_GLOBALS format is for internal use only") raise ValueError("The VALUE_WITH_FAKE_GLOBALS format is for internal use only")
case _: case _:

View file

@ -276,10 +276,10 @@ class TestStringFormat(unittest.TestCase):
def test_getitem(self): def test_getitem(self):
def f(x: undef1[str, undef2]): def f(x: undef1[str, undef2]):
pass pass
anno = annotationlib.get_annotations(f, format=Format.STRING) anno = get_annotations(f, format=Format.STRING)
self.assertEqual(anno, {"x": "undef1[str, undef2]"}) self.assertEqual(anno, {"x": "undef1[str, undef2]"})
anno = annotationlib.get_annotations(f, format=Format.FORWARDREF) anno = get_annotations(f, format=Format.FORWARDREF)
fwdref = anno["x"] fwdref = anno["x"]
self.assertIsInstance(fwdref, ForwardRef) self.assertIsInstance(fwdref, ForwardRef)
self.assertEqual( self.assertEqual(
@ -289,18 +289,18 @@ class TestStringFormat(unittest.TestCase):
def test_slice(self): def test_slice(self):
def f(x: a[b:c]): def f(x: a[b:c]):
pass pass
anno = annotationlib.get_annotations(f, format=Format.STRING) anno = get_annotations(f, format=Format.STRING)
self.assertEqual(anno, {"x": "a[b:c]"}) self.assertEqual(anno, {"x": "a[b:c]"})
def f(x: a[b:c, d:e]): def f(x: a[b:c, d:e]):
pass pass
anno = annotationlib.get_annotations(f, format=Format.STRING) anno = get_annotations(f, format=Format.STRING)
self.assertEqual(anno, {"x": "a[b:c, d:e]"}) self.assertEqual(anno, {"x": "a[b:c, d:e]"})
obj = slice(1, 1, 1) obj = slice(1, 1, 1)
def f(x: obj): def f(x: obj):
pass pass
anno = annotationlib.get_annotations(f, format=Format.STRING) anno = get_annotations(f, format=Format.STRING)
self.assertEqual(anno, {"x": "obj"}) self.assertEqual(anno, {"x": "obj"})
def test_literals(self): def test_literals(self):
@ -316,7 +316,7 @@ class TestStringFormat(unittest.TestCase):
): ):
pass pass
anno = annotationlib.get_annotations(f, format=Format.STRING) anno = get_annotations(f, format=Format.STRING)
self.assertEqual( self.assertEqual(
anno, anno,
{ {
@ -335,7 +335,7 @@ class TestStringFormat(unittest.TestCase):
# Simple case first # Simple case first
def f(x: a[[int, str], float]): def f(x: a[[int, str], float]):
pass pass
anno = annotationlib.get_annotations(f, format=Format.STRING) anno = get_annotations(f, format=Format.STRING)
self.assertEqual(anno, {"x": "a[[int, str], float]"}) self.assertEqual(anno, {"x": "a[[int, str], float]"})
def g( def g(
@ -345,7 +345,7 @@ class TestStringFormat(unittest.TestCase):
z: a[(int, str), 5], z: a[(int, str), 5],
): ):
pass pass
anno = annotationlib.get_annotations(g, format=Format.STRING) anno = get_annotations(g, format=Format.STRING)
self.assertEqual( self.assertEqual(
anno, anno,
{ {
@ -1017,6 +1017,58 @@ class TestGetAnnotations(unittest.TestCase):
set(results.generic_func.__type_params__), set(results.generic_func.__type_params__),
) )
def test_partial_evaluation(self):
def f(
x: builtins.undef,
y: list[int],
z: 1 + int,
a: builtins.int,
b: [builtins.undef, builtins.int],
):
pass
self.assertEqual(
get_annotations(f, format=Format.FORWARDREF),
{
"x": support.EqualToForwardRef("builtins.undef", owner=f),
"y": list[int],
"z": support.EqualToForwardRef("1 + int", owner=f),
"a": int,
"b": [
support.EqualToForwardRef("builtins.undef", owner=f),
# We can't resolve this because we have to evaluate the whole annotation
support.EqualToForwardRef("builtins.int", owner=f),
],
},
)
self.assertEqual(
get_annotations(f, format=Format.STRING),
{
"x": "builtins.undef",
"y": "list[int]",
"z": "1 + int",
"a": "builtins.int",
"b": "[builtins.undef, builtins.int]",
},
)
def test_partial_evaluation_cell(self):
obj = object()
class RaisesAttributeError:
attriberr: obj.missing
anno = get_annotations(RaisesAttributeError, format=Format.FORWARDREF)
self.assertEqual(
anno,
{
"attriberr": support.EqualToForwardRef(
"obj.missing", is_class=True, owner=RaisesAttributeError
)
},
)
class TestCallEvaluateFunction(unittest.TestCase): class TestCallEvaluateFunction(unittest.TestCase):
def test_evaluation(self): def test_evaluation(self):
@ -1370,6 +1422,38 @@ class TestForwardRefClass(unittest.TestCase):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
pickle.dumps(fr, proto) pickle.dumps(fr, proto)
def test_evaluate_string_format(self):
fr = ForwardRef("set[Any]")
self.assertEqual(fr.evaluate(format=Format.STRING), "set[Any]")
def test_evaluate_forwardref_format(self):
fr = ForwardRef("undef")
evaluated = fr.evaluate(format=Format.FORWARDREF)
self.assertIs(fr, evaluated)
fr = ForwardRef("set[undefined]")
evaluated = fr.evaluate(format=Format.FORWARDREF)
self.assertEqual(
evaluated,
set[support.EqualToForwardRef("undefined")],
)
fr = ForwardRef("a + b")
self.assertEqual(
fr.evaluate(format=Format.FORWARDREF),
support.EqualToForwardRef("a + b"),
)
self.assertEqual(
fr.evaluate(format=Format.FORWARDREF, locals={"a": 1, "b": 2}),
3,
)
fr = ForwardRef('"a" + 1')
self.assertEqual(
fr.evaluate(format=Format.FORWARDREF),
support.EqualToForwardRef('"a" + 1'),
)
def test_evaluate_with_type_params(self): def test_evaluate_with_type_params(self):
class Gen[T]: class Gen[T]:
alias = int alias = int

View file

@ -0,0 +1,3 @@
Add a *format* parameter to :meth:`annotationlib.ForwardRef.evaluate`.
Evaluating annotations in the ``FORWARDREF`` format now succeeds in more
cases that would previously have raised an exception.