Remove redundant type_to_visitor_function entries (#17564)

This commit is contained in:
Shaygan Hooshyari 2025-04-23 09:27:00 +02:00 committed by GitHub
parent f36262d970
commit 3fae176345
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -40,6 +40,23 @@ types_requiring_create_prefix = {
}
@dataclass
class VisitorInfo:
name: str
accepts_sequence: bool = False
# Map of AST node types to their corresponding visitor information.
# Only visitors that are different from the default `visit_*` method are included.
# These visitors either have a different name or accept a sequence of items.
type_to_visitor_function: dict[str, VisitorInfo] = {
"TypeParams": VisitorInfo("visit_type_params", True),
"Parameters": VisitorInfo("visit_parameters", True),
"Stmt": VisitorInfo("visit_body", True),
"Arguments": VisitorInfo("visit_arguments", True),
}
def rustfmt(code: str) -> str:
return check_output(["rustfmt", "--emit=stdout"], input=code, text=True)
@ -202,6 +219,7 @@ def extract_type_argument(rust_type_str: str) -> str:
if close_bracket_index == -1 or close_bracket_index <= open_bracket_index:
raise ValueError(f"Brackets are not balanced for type {rust_type_str}")
inner_type = rust_type_str[open_bracket_index + 1 : close_bracket_index].strip()
inner_type = inner_type.replace("crate::", "")
return inner_type
@ -766,39 +784,6 @@ def write_node(out: list[str], ast: Ast) -> None:
# Source order visitor
@dataclass
class VisitorInfo:
name: str
accepts_sequence: bool = False
# Map of AST node types to their corresponding visitor information
type_to_visitor_function: dict[str, VisitorInfo] = {
"Decorator": VisitorInfo("visit_decorator"),
"Identifier": VisitorInfo("visit_identifier"),
"crate::TypeParams": VisitorInfo("visit_type_params", True),
"crate::Parameters": VisitorInfo("visit_parameters", True),
"Expr": VisitorInfo("visit_expr"),
"Stmt": VisitorInfo("visit_body", True),
"Arguments": VisitorInfo("visit_arguments", True),
"crate::Arguments": VisitorInfo("visit_arguments", True),
"Operator": VisitorInfo("visit_operator"),
"ElifElseClause": VisitorInfo("visit_elif_else_clause"),
"WithItem": VisitorInfo("visit_with_item"),
"MatchCase": VisitorInfo("visit_match_case"),
"ExceptHandler": VisitorInfo("visit_except_handler"),
"Alias": VisitorInfo("visit_alias"),
"UnaryOp": VisitorInfo("visit_unary_op"),
"DictItem": VisitorInfo("visit_dict_item"),
"Comprehension": VisitorInfo("visit_comprehension"),
"CmpOp": VisitorInfo("visit_cmp_op"),
"FStringValue": VisitorInfo("visit_f_string_value"),
"StringLiteralValue": VisitorInfo("visit_string_literal"),
"BytesLiteralValue": VisitorInfo("visit_bytes_literal"),
}
annotation_visitor_function = VisitorInfo("visit_annotation")
def write_source_order(out: list[str], ast: Ast) -> None:
for group in ast.groups:
for node in group.nodes:
@ -816,24 +801,33 @@ def write_source_order(out: list[str], ast: Ast) -> None:
fields_list += "range: _,\n"
for field in node.fields_in_source_order():
visitor = type_to_visitor_function[field.parsed_ty.inner]
visitor_name = (
type_to_visitor_function.get(
field.parsed_ty.inner, VisitorInfo("")
).name
or f"visit_{to_snake_case(field.parsed_ty.inner)}"
)
visits_sequence = type_to_visitor_function.get(
field.parsed_ty.inner, VisitorInfo("")
).accepts_sequence
if field.is_annotation:
visitor = annotation_visitor_function
visitor_name = "visit_annotation"
if field.parsed_ty.optional:
body += f"""
if let Some({field.name}) = {field.name} {{
visitor.{visitor.name}({field.name});
visitor.{visitor_name}({field.name});
}}\n
"""
elif not visitor.accepts_sequence and field.parsed_ty.seq:
elif not visits_sequence and field.parsed_ty.seq:
body += f"""
for elm in {field.name} {{
visitor.{visitor.name}(elm);
visitor.{visitor_name}(elm);
}}
"""
else:
body += f"visitor.{visitor.name}({field.name});\n"
body += f"visitor.{visitor_name}({field.name});\n"
visitor_arg_name = "visitor"
if len(node.fields_in_source_order()) == 0: