mirror of
https://github.com/astral-sh/ruff.git
synced 2025-08-23 11:54:39 +00:00
Auto generate visit_source_order
(#17180)
## Summary part of: #15655 I tried generating the source order function using code generation. I tried a simple approach, but it is not enough to generate all of them this way. The good news is that most of the implementations work fine with this approach; only a few do not. So one benefit of this PR is that it eliminates a lot of code, so changing the AST structure will leave only a few places to be fixed. The `source_order` field determines if a node requires a source order implementation. If it’s empty it means source order does not visit anything. Initially I didn’t want to repeat the field names. But I found two things: - `ExprIf`, unlike other nodes, does not have its fields defined in source order. Also, some fields do not need to be included in the visit. So we just need a way to determine order, and determine presence. - Relying on the fields sounds more complicated to me. Maybe another solution is to add a new attribute `order` to each field? I'm open to suggestions. But anyway, except for `ExprIf` we don't need to write the field names in order. Just knowing which fields must be visited is enough. Some nodes had a more complex visitor: `ExprCompare` required zipping two fields. `ExprBoolOp` required a match over the fields. `FStringValue` required a match; I created a new walk_ function that does the match and used it in code generation. I don’t think this provides real value, because I mostly moved the code from one file to another. I tried it as an option. I prefer to leave it in the code as before. Some visitors visit a slice of items. Others visit a single element. I put a check on this in code generation to see if the field requires a for loop or not. I think a better approach is to have a consistent style, so we can by default loop over any field that is a sequence. For field types `StringLiteralValue` and `BytesLiteralValue` the types are not a sequence in the toml definition. 
But they implement `iter` so they are iterated over. So the code generation does not properly identify this. So in the code I'm checking for their types. ## Test Plan All the tests should pass without any changes. I checked the generated code to make sure it's the same as old code. I'm not sure if there's a test for the source order visitor.
This commit is contained in:
parent
bd89838212
commit
3ada36b766
5 changed files with 1048 additions and 886 deletions
|
@ -15,7 +15,7 @@ from typing import Any
|
|||
import tomllib
|
||||
|
||||
# Types that require `crate::`. We can slowly remove these types as we move them to generate scripts.
|
||||
types_requiring_create_prefix = [
|
||||
types_requiring_create_prefix = {
|
||||
"IpyEscapeKind",
|
||||
"ExprContext",
|
||||
"Identifier",
|
||||
|
@ -33,12 +33,11 @@ types_requiring_create_prefix = [
|
|||
"Decorator",
|
||||
"TypeParams",
|
||||
"Parameters",
|
||||
"Arguments",
|
||||
"ElifElseClause",
|
||||
"WithItem",
|
||||
"MatchCase",
|
||||
"Alias",
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def rustfmt(code: str) -> str:
|
||||
|
@ -124,6 +123,8 @@ class Node:
|
|||
doc: str | None
|
||||
fields: list[Field] | None
|
||||
derives: list[str]
|
||||
custom_source_order: bool
|
||||
source_order: list[str] | None
|
||||
|
||||
def __init__(self, group: Group, node_name: str, node: dict[str, Any]) -> None:
|
||||
self.name = node_name
|
||||
|
@ -133,26 +134,82 @@ class Node:
|
|||
fields = node.get("fields")
|
||||
if fields is not None:
|
||||
self.fields = [Field(f) for f in fields]
|
||||
self.custom_source_order = node.get("custom_source_order", False)
|
||||
self.derives = node.get("derives", [])
|
||||
self.doc = node.get("doc")
|
||||
self.source_order = node.get("source_order")
|
||||
|
||||
def fields_in_source_order(self) -> list[Field]:
    """Return the fields the source-order visitor must visit, in visiting order.

    Fields for which `skip_source_order()` is true are never returned.

    * When the node has no fields, the result is empty.
    * When `source_order` is `None`, the visitable fields are returned in
      declaration order.
    * Otherwise the visitable fields are returned in the order their names
      appear in `source_order`.

    Raises:
        ValueError: if a name in `source_order` does not match any visitable
            field. Previously the last field that happened to be iterated was
            appended instead, which silently generated wrong visitor code.
    """
    if self.fields is None:
        return []

    visitable = [field for field in self.fields if not field.skip_source_order()]
    if self.source_order is None:
        return visitable

    # First declaration wins, matching the first-match behavior of a linear scan.
    by_name: dict[str, Field] = {}
    for field in visitable:
        by_name.setdefault(field.name, field)

    ordered = []
    for field_name in self.source_order:
        if field_name not in by_name:
            raise ValueError(
                f"`source_order` of node `{self.name}` references `{field_name}`, "
                "which is not a visitable field of that node"
            )
        ordered.append(by_name[field_name])
    return ordered
|
||||
|
||||
|
||||
@dataclass
class Field:
    """One field of an AST node, built from its entry in the node definition."""

    name: str
    ty: str
    _skip_visit: bool
    is_annotation: bool
    parsed_ty: FieldType

    def __init__(self, field: dict[str, Any]) -> None:
        """Populate the field from its definition mapping.

        `name` and `type` are required keys; `skip_visit` and `is_annotation`
        are optional and default to `False`.
        """
        field_name = field["name"]
        raw_type = field["type"]
        self.name = field_name
        self.ty = raw_type
        self.parsed_ty = FieldType(raw_type)
        self._skip_visit = field.get("skip_visit", False)
        self.is_annotation = field.get("is_annotation", False)

    def skip_source_order(self) -> bool:
        """Tell whether the source-order visitor should leave this field out.

        A field is left out when it was explicitly marked with `skip_visit`,
        or when its inner type is one of the type names listed below.
        """
        if self._skip_visit:
            return self._skip_visit
        never_visited = (
            "str",
            "ExprContext",
            "Name",
            "u32",
            "bool",
            "Number",
            "IpyEscapeKind",
        )
        return self.parsed_ty.inner in never_visited
|
||||
|
||||
|
||||
# Returns the type argument of a rust type written in the AST field type syntax,
# after dropping the `*`, `?` and `&` markers:
#   Box<str>   -> str
#   Box<Expr?> -> Expr
# A type without a type argument is returned as-is (minus the markers).
# Nested type arguments are not unwrapped: only the outermost `<...>` is removed.
def extract_type_argument(rust_type_str: str) -> str:
    cleaned = rust_type_str.translate(str.maketrans("", "", "*?&"))

    start = cleaned.find("<")
    if start < 0:
        return cleaned

    # `rfind` yields -1 when there is no `>`, which is also caught by `<=`.
    end = cleaned.rfind(">")
    if end <= start:
        raise ValueError(f"Brackets are not balanced for type {cleaned}")

    return cleaned[start + 1 : end].strip()
|
||||
|
||||
|
||||
@dataclass
|
||||
class FieldType:
|
||||
rule: str
|
||||
name: str
|
||||
inner: str
|
||||
seq: bool = False
|
||||
optional: bool = False
|
||||
slice_: bool = False
|
||||
|
@ -160,6 +217,7 @@ class FieldType:
|
|||
def __init__(self, rule: str) -> None:
|
||||
self.rule = rule
|
||||
self.name = ""
|
||||
self.inner = extract_type_argument(rule)
|
||||
|
||||
# The following cases are the limitations of this parser(and not used in the ast.toml):
|
||||
# * Rules that involve declaring a sequence with optional items e.g. Vec<Option<...>>
|
||||
|
@ -201,6 +259,7 @@ def write_preamble(out: list[str]) -> None:
|
|||
// Run `crates/ruff_python_ast/generate.py` to re-generate the file.
|
||||
|
||||
use crate::name::Name;
|
||||
use crate::visitor::source_order::SourceOrderVisitor;
|
||||
""")
|
||||
|
||||
|
||||
|
@ -703,6 +762,98 @@ def write_node(out: list[str], ast: Ast) -> None:
|
|||
out.append("")
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Source order visitor
|
||||
|
||||
|
||||
@dataclass
class VisitorInfo:
    """Describes the visitor method that the generated code calls for a field type."""

    # Name of the method invoked on the visitor, e.g. "visit_expr".
    name: str
    # When True the method is handed the whole field in a single call, so no
    # per-element `for` loop is generated even if the field is a sequence.
    accepts_sequence: bool = False
|
||||
|
||||
|
||||
# Map of AST node types to their corresponding visitor information.
# Keys are compared against `field.parsed_ty.inner` when generating
# `visit_source_order`; a field type that has no entry here raises `KeyError`.
# NOTE(review): `Arguments` is listed both with and without the `crate::` prefix,
# while `TypeParams` and `Parameters` only appear prefixed - presumably this mirrors
# how the types are spelled in the node definitions; confirm before adding entries.
type_to_visitor_function: dict[str, VisitorInfo] = {
    "Decorator": VisitorInfo("visit_decorator"),
    "Identifier": VisitorInfo("visit_identifier"),
    "crate::TypeParams": VisitorInfo("visit_type_params", True),
    "crate::Parameters": VisitorInfo("visit_parameters", True),
    "Expr": VisitorInfo("visit_expr"),
    "Stmt": VisitorInfo("visit_body", True),
    "Arguments": VisitorInfo("visit_arguments", True),
    "crate::Arguments": VisitorInfo("visit_arguments", True),
    "Operator": VisitorInfo("visit_operator"),
    "ElifElseClause": VisitorInfo("visit_elif_else_clause"),
    "WithItem": VisitorInfo("visit_with_item"),
    "MatchCase": VisitorInfo("visit_match_case"),
    "ExceptHandler": VisitorInfo("visit_except_handler"),
    "Alias": VisitorInfo("visit_alias"),
    "UnaryOp": VisitorInfo("visit_unary_op"),
    "DictItem": VisitorInfo("visit_dict_item"),
    "Comprehension": VisitorInfo("visit_comprehension"),
    "CmpOp": VisitorInfo("visit_cmp_op"),
    "FStringValue": VisitorInfo("visit_f_string_value"),
    "StringLiteralValue": VisitorInfo("visit_string_literal"),
    "BytesLiteralValue": VisitorInfo("visit_bytes_literal"),
}
# Used in place of the table entry for fields flagged `is_annotation`.
annotation_visitor_function = VisitorInfo("visit_annotation")
|
||||
|
||||
|
||||
def write_source_order(out: list[str], ast: Ast) -> None:
    """Append a Rust `visit_source_order` impl to `out` for every eligible node.

    Nodes without fields, and nodes flagged `custom_source_order`, are skipped.
    For each remaining node the generated method destructures `self` and calls
    one visitor method per field returned by `node.fields_in_source_order()`.

    Args:
        out: list of generated source fragments; one string is appended per node.
        ast: the parsed AST description whose groups and nodes are iterated.

    Raises:
        KeyError: if a visited field's inner type has no entry in
            `type_to_visitor_function`.
    """
    for group in ast.groups:
        for node in group.nodes:
            if node.fields is None or node.custom_source_order:
                continue
            name = node.name
            fields_list = ""
            body = ""

            # Destructuring pattern: fields that are never visited are bound to `_`,
            # the rest are bound by name. `range` is always ignored.
            for field in node.fields:
                if field.skip_source_order():
                    fields_list += f"{field.name}: _,\n"
                else:
                    fields_list += f"{field.name},\n"
            fields_list += "range: _,\n"

            # NOTE(review): the whitespace inside the Rust templates below is
            # presumably normalized by a later rustfmt pass - confirm.
            for field in node.fields_in_source_order():
                # The table lookup happens before the annotation override, so an
                # annotation field still needs an entry for its inner type.
                visitor = type_to_visitor_function[field.parsed_ty.inner]
                if field.is_annotation:
                    visitor = annotation_visitor_function

                if field.parsed_ty.optional:
                    # Optional field: visit only when a value is present.
                    body += f"""
if let Some({field.name}) = {field.name} {{
visitor.{visitor.name}({field.name});
}}\n
"""
                elif not visitor.accepts_sequence and field.parsed_ty.seq:
                    # Sequence field whose visitor takes one element at a time.
                    body += f"""
for elm in {field.name} {{
visitor.{visitor.name}(elm);
}}
"""
                else:
                    # Single value, or a sequence the visitor accepts as a whole.
                    body += f"visitor.{visitor.name}({field.name});\n"

            # Name the visitor parameter `_` when nothing is visited.
            visitor_arg_name = "visitor"
            if len(node.fields_in_source_order()) == 0:
                visitor_arg_name = "_"

            out.append(f"""
impl {name} {{
pub(crate) fn visit_source_order<'a, V>(&'a self, {visitor_arg_name}: &mut V)
where
V: SourceOrderVisitor<'a> + ?Sized,
{{
let {name} {{
{fields_list}
}} = self;
{body}
}}
}}
""")
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Format and write output
|
||||
|
||||
|
@ -715,6 +866,7 @@ def generate(ast: Ast) -> list[str]:
|
|||
write_anynoderef(out, ast)
|
||||
write_nodekind(out, ast)
|
||||
write_node(out, ast)
|
||||
write_source_order(out, ast)
|
||||
return out
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue