gh-105858: Improve AST node constructors (#105880)

Demonstration:

>>> ast.FunctionDef.__annotations__
{'name': <class 'str'>, 'args': <class 'ast.arguments'>, 'body': list[ast.stmt], 'decorator_list': list[ast.expr], 'returns': ast.expr | None, 'type_comment': str | None, 'type_params': list[ast.type_param]}
>>> ast.FunctionDef()
<stdin>:1: DeprecationWarning: FunctionDef.__init__ missing 1 required positional argument: 'name'. This will become an error in Python 3.15.
<stdin>:1: DeprecationWarning: FunctionDef.__init__ missing 1 required positional argument: 'args'. This will become an error in Python 3.15.
<ast.FunctionDef object at 0x101959460>
>>> node = ast.FunctionDef(name="foo", args=ast.arguments())
>>> node.decorator_list
[]
>>> ast.FunctionDef(whatever="you want", name="x", args=ast.arguments())
<stdin>:1: DeprecationWarning: FunctionDef.__init__ got an unexpected keyword argument 'whatever'. Support for arbitrary keyword arguments is deprecated and will be removed in Python 3.15.
<ast.FunctionDef object at 0x1019581f0>
This commit is contained in:
Jelle Zijlstra 2024-02-27 18:13:03 -08:00 committed by GitHub
parent 5a1559d949
commit ed4dfd8825
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 4675 additions and 49 deletions

View file

@ -15,6 +15,13 @@ TABSIZE = 4
MAX_COL = 80
AUTOGEN_MESSAGE = "// File automatically generated by {}.\n\n"
# Maps builtin field type names to the name of the C type object that
# represents them in generated annotations. emit_annotations() emits the
# value as `(PyObject *)&<name>`; field types that are not keys here are
# instead emitted as `state-><type>_type`.
builtin_type_to_c_type = {
"identifier": "PyUnicode_Type",
"string": "PyUnicode_Type",
"int": "PyLong_Type",
"constant": "PyBaseObject_Type",
}
def get_c_type(name):
"""Return a string for the C name of the type.
@ -764,6 +771,67 @@ class PyTypesDeclareVisitor(PickleVisitor):
self.emit("};",0)
# Visitor that writes the C function `add_ast_annotations(state)`. For every
# product type and every sum constructor it emits C code that builds a dict
# mapping field name -> type object and stores that dict on the node type as
# both `_field_types` and `__annotations__`.
#
# NOTE(review): leading indentation is missing in this view of the file, so
# the nesting described in the comments below is inferred from the C that is
# emitted (balanced braces, one dict insertion per field) -- confirm against
# the properly indented source.
# NOTE(review): the second argument to self.emit() looks like an indentation
# depth for the emitted C line (see `depth` / `depth + 1` in
# emit_annotations_error); emit() itself is defined elsewhere -- confirm.
class AnnotationsVisitor(PickleVisitor):
# Write the C function prologue (which declares the shared `bool cond`),
# visit every definition in `mod.dfns`, then write the epilogue that makes
# the generated function return 1.
def visitModule(self, mod):
self.file.write(textwrap.dedent('''
static int
add_ast_annotations(struct ast_state *state)
{
bool cond;
'''))
for dfn in mod.dfns:
self.visit(dfn)
self.file.write(textwrap.dedent('''
return 1;
}
'''))
# A product type is annotated under its own name, using its own fields.
def visitProduct(self, prod, name):
self.emit_annotations(name, prod.fields)
# A sum type emits nothing itself; each entry of `sum.types` is forwarded to
# visitConstructor.
def visitSum(self, sum, name):
for t in sum.types:
self.visitConstructor(t, name)
# Annotations are emitted under the constructor's own name (`cons.name`);
# the `name` argument passed in by visitSum is not used here.
def visitConstructor(self, cons, name):
self.emit_annotations(cons.name, cons.fields)
# Emit the C code that builds and installs the annotations dict for one
# node type.
#   name   -- used to form the C local `{name}_annotations` and the target
#             type `state->{name}_type`
#   fields -- iterable of field objects; `.type`, `.opt`, `.seq` and `.name`
#             are read from each
# Every failure path in the emitted C returns 0 from add_ast_annotations.
def emit_annotations(self, name, fields):
self.emit(f"PyObject *{name}_annotations = PyDict_New();", 1)
self.emit(f"if (!{name}_annotations) return 0;", 1)
for field in fields:
# Each field gets its own C block so the local `type` can be redeclared.
self.emit("{", 1)
# Base type object: a static builtin type for names listed in
# builtin_type_to_c_type, otherwise the type stored on the state struct.
if field.type in builtin_type_to_c_type:
self.emit(f"PyObject *type = (PyObject *)&{builtin_type_to_c_type[field.type]};", 2)
else:
self.emit(f"PyObject *type = state->{field.type}_type;", 2)
# Optional fields become `type | None`, sequence fields become
# `list[type]`; both results are NULL-checked via `cond`.
if field.opt:
self.emit("type = _Py_union_type_or(type, Py_None);", 2)
self.emit("cond = type != NULL;", 2)
self.emit_annotations_error(name, 2)
elif field.seq:
self.emit("type = Py_GenericAlias((PyObject *)&PyList_Type, type);", 2)
self.emit("cond = type != NULL;", 2)
self.emit_annotations_error(name, 2)
else:
# Plain field: presumably `type` is only borrowed at this point, so a
# reference is taken to balance the Py_DECREF emitted after the dict
# insertion below -- TODO confirm.
self.emit("Py_INCREF(type);", 2)
# Insert under the field's name, release `type`, and only then test `cond`.
self.emit(f"cond = PyDict_SetItemString({name}_annotations, \"{field.name}\", type) == 0;", 2)
self.emit("Py_DECREF(type);", 2)
self.emit_annotations_error(name, 2)
self.emit("}", 1)
# The same dict object is installed under two attribute names on the type.
self.emit(f'cond = PyObject_SetAttrString(state->{name}_type, "_field_types", {name}_annotations) == 0;', 1)
self.emit_annotations_error(name, 1)
self.emit(f'cond = PyObject_SetAttrString(state->{name}_type, "__annotations__", {name}_annotations) == 0;', 1)
self.emit_annotations_error(name, 1)
self.emit(f"Py_DECREF({name}_annotations);", 1)
# Emit the shared failure check: if `cond` is false, release
# `{name}_annotations` and return 0. `depth` is passed through to
# self.emit(); the two lines inside the `if` use `depth + 1`.
def emit_annotations_error(self, name, depth):
self.emit("if (!cond) {", depth)
self.emit(f"Py_DECREF({name}_annotations);", depth + 1)
self.emit("return 0;", depth + 1)
self.emit("}", depth)
class PyTypesVisitor(PickleVisitor):
def visitModule(self, mod):
@ -812,7 +880,7 @@ ast_type_init(PyObject *self, PyObject *args, PyObject *kw)
Py_ssize_t i, numfields = 0;
int res = -1;
PyObject *key, *value, *fields;
PyObject *key, *value, *fields, *remaining_fields = NULL;
if (PyObject_GetOptionalAttr((PyObject*)Py_TYPE(self), state->_fields, &fields) < 0) {
goto cleanup;
}
@ -821,6 +889,13 @@ ast_type_init(PyObject *self, PyObject *args, PyObject *kw)
if (numfields == -1) {
goto cleanup;
}
remaining_fields = PySet_New(fields);
}
else {
remaining_fields = PySet_New(NULL);
}
if (remaining_fields == NULL) {
goto cleanup;
}
res = 0; /* if no error occurs, this stays 0 to the end */
@ -840,6 +915,11 @@ ast_type_init(PyObject *self, PyObject *args, PyObject *kw)
goto cleanup;
}
res = PyObject_SetAttr(self, name, PyTuple_GET_ITEM(args, i));
if (PySet_Discard(remaining_fields, name) < 0) {
res = -1;
Py_DECREF(name);
goto cleanup;
}
Py_DECREF(name);
if (res < 0) {
goto cleanup;
@ -852,13 +932,14 @@ ast_type_init(PyObject *self, PyObject *args, PyObject *kw)
if (contains == -1) {
res = -1;
goto cleanup;
} else if (contains == 1) {
Py_ssize_t p = PySequence_Index(fields, key);
}
else if (contains == 1) {
int p = PySet_Discard(remaining_fields, key);
if (p == -1) {
res = -1;
goto cleanup;
}
if (p < PyTuple_GET_SIZE(args)) {
if (p == 0) {
PyErr_Format(PyExc_TypeError,
"%.400s got multiple values for argument '%U'",
Py_TYPE(self)->tp_name, key);
@ -866,15 +947,91 @@ ast_type_init(PyObject *self, PyObject *args, PyObject *kw)
goto cleanup;
}
}
else if (
PyUnicode_CompareWithASCIIString(key, "lineno") != 0 &&
PyUnicode_CompareWithASCIIString(key, "col_offset") != 0 &&
PyUnicode_CompareWithASCIIString(key, "end_lineno") != 0 &&
PyUnicode_CompareWithASCIIString(key, "end_col_offset") != 0
) {
if (PyErr_WarnFormat(
PyExc_DeprecationWarning, 1,
"%.400s.__init__ got an unexpected keyword argument '%U'. "
"Support for arbitrary keyword arguments is deprecated "
"and will be removed in Python 3.15.",
Py_TYPE(self)->tp_name, key
) < 0) {
res = -1;
goto cleanup;
}
}
res = PyObject_SetAttr(self, key, value);
if (res < 0) {
goto cleanup;
}
}
}
Py_ssize_t size = PySet_Size(remaining_fields);
PyObject *field_types = NULL, *remaining_list = NULL;
if (size > 0) {
if (!PyObject_GetOptionalAttr((PyObject*)Py_TYPE(self), &_Py_ID(_field_types),
&field_types)) {
res = -1;
goto cleanup;
}
remaining_list = PySequence_List(remaining_fields);
if (!remaining_list) {
goto set_remaining_cleanup;
}
for (Py_ssize_t i = 0; i < size; i++) {
PyObject *name = PyList_GET_ITEM(remaining_list, i);
PyObject *type = PyDict_GetItemWithError(field_types, name);
if (!type) {
if (!PyErr_Occurred()) {
PyErr_SetObject(PyExc_KeyError, name);
}
goto set_remaining_cleanup;
}
if (_PyUnion_Check(type)) {
// optional field
// do nothing, we'll have set a None default on the class
}
else if (Py_IS_TYPE(type, &Py_GenericAliasType)) {
// list field
PyObject *empty = PyList_New(0);
if (!empty) {
goto set_remaining_cleanup;
}
res = PyObject_SetAttr(self, name, empty);
Py_DECREF(empty);
if (res < 0) {
goto set_remaining_cleanup;
}
}
else {
// simple field (e.g., identifier)
if (PyErr_WarnFormat(
PyExc_DeprecationWarning, 1,
"%.400s.__init__ missing 1 required positional argument: '%U'. "
"This will become an error in Python 3.15.",
Py_TYPE(self)->tp_name, name
) < 0) {
res = -1;
goto cleanup;
}
}
}
Py_DECREF(remaining_list);
Py_DECREF(field_types);
}
cleanup:
Py_XDECREF(fields);
Py_XDECREF(remaining_fields);
return res;
set_remaining_cleanup:
Py_XDECREF(remaining_list);
Py_XDECREF(field_types);
res = -1;
goto cleanup;
}
/* Pickling support */
@ -886,14 +1043,75 @@ ast_type_reduce(PyObject *self, PyObject *unused)
return NULL;
}
PyObject *dict;
PyObject *dict = NULL, *fields = NULL, *remaining_fields = NULL,
*remaining_dict = NULL, *positional_args = NULL;
if (PyObject_GetOptionalAttr(self, state->__dict__, &dict) < 0) {
return NULL;
}
PyObject *result = NULL;
if (dict) {
return Py_BuildValue("O()N", Py_TYPE(self), dict);
// Serialize the fields as positional args if possible, because if we
// serialize them as a dict, during unpickling they are set only *after*
// the object is constructed, which will now trigger a DeprecationWarning
// if the AST type has required fields.
if (PyObject_GetOptionalAttr((PyObject*)Py_TYPE(self), state->_fields, &fields) < 0) {
goto cleanup;
}
if (fields) {
Py_ssize_t numfields = PySequence_Size(fields);
if (numfields == -1) {
Py_DECREF(dict);
goto cleanup;
}
remaining_dict = PyDict_Copy(dict);
Py_DECREF(dict);
if (!remaining_dict) {
goto cleanup;
}
positional_args = PyList_New(0);
if (!positional_args) {
goto cleanup;
}
for (Py_ssize_t i = 0; i < numfields; i++) {
PyObject *name = PySequence_GetItem(fields, i);
if (!name) {
goto cleanup;
}
PyObject *value = PyDict_GetItemWithError(remaining_dict, name);
if (!value) {
if (PyErr_Occurred()) {
goto cleanup;
}
break;
}
if (PyList_Append(positional_args, value) < 0) {
goto cleanup;
}
if (PyDict_DelItem(remaining_dict, name) < 0) {
goto cleanup;
}
Py_DECREF(name);
}
PyObject *args_tuple = PyList_AsTuple(positional_args);
if (!args_tuple) {
goto cleanup;
}
result = Py_BuildValue("ONO", Py_TYPE(self), args_tuple,
remaining_dict);
}
else {
result = Py_BuildValue("O()N", Py_TYPE(self), dict);
}
}
return Py_BuildValue("O()", Py_TYPE(self));
else {
result = Py_BuildValue("O()", Py_TYPE(self));
}
cleanup:
Py_XDECREF(fields);
Py_XDECREF(remaining_fields);
Py_XDECREF(remaining_dict);
Py_XDECREF(positional_args);
return result;
}
static PyMemberDef ast_type_members[] = {
@ -1117,6 +1335,9 @@ static int add_ast_fields(struct ast_state *state)
for dfn in mod.dfns:
self.visit(dfn)
self.file.write(textwrap.dedent('''
if (!add_ast_annotations(state)) {
return -1;
}
return 0;
}
'''))
@ -1534,6 +1755,8 @@ def generate_module_def(mod, metadata, f, internal_h):
#include "pycore_lock.h" // _PyOnceFlag
#include "pycore_interp.h" // _PyInterpreterState.ast
#include "pycore_pystate.h" // _PyInterpreterState_GET()
#include "pycore_unionobject.h" // _Py_union_type_or
#include "structmember.h"
#include <stddef.h>
struct validator {
@ -1651,6 +1874,7 @@ def write_source(mod, metadata, f, internal_h_file):
v = ChainOfVisitors(
SequenceConstructorVisitor(f),
PyTypesDeclareVisitor(f),
AnnotationsVisitor(f),
PyTypesVisitor(f),
Obj2ModPrototypeVisitor(f),
FunctionVisitor(f),