Separate grouped and ungrouped nodes more clearly in AST generator (#15646)

This is a minor cleanup to the AST generation script to make a clearer
separation between nodes that do appear in a group enum, and those that
don't. There are some types and methods that we create for every syntax
node, and others that refer to the group that the syntax node belongs
to, and which therefore don't make sense for ungrouped nodes. This new
separation makes it clearer which category each definition is in, since
you're either inside of a `for group in ast.groups` loop, or a `for node
in ast.all_nodes` loop.
This commit is contained in:
Douglas Creager 2025-01-21 13:37:18 -05:00 committed by GitHub
parent fce4adfd41
commit fa546b20a6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 254 additions and 248 deletions

View file

@ -28,17 +28,41 @@ def to_snake_case(node: str) -> str:
# Read AST description
def load_ast(root: Path) -> list[Group]:
def load_ast(root: Path) -> Ast:
ast_path = root.joinpath("crates", "ruff_python_ast", "ast.toml")
with ast_path.open("rb") as ast_file:
ast = tomllib.load(ast_file)
return [Group(group_name, group) for group_name, group in ast.items()]
return Ast(ast)
# ------------------------------------------------------------------------------
# Preprocess
@dataclass
class Ast:
    """
    In-memory form of the `ast.toml` file: every Python AST syntax node,
    together with the groups (`Stmt`, `Expr`, etc.) the nodes are sorted into.

    Attributes are filled in by `__init__` from the already-parsed TOML
    mapping, one top-level table at a time, in the mapping's iteration order.
    """

    # Every table except the one named "ungrouped", wrapped as a `Group`.
    groups: list[Group]
    # The nodes of the "ungrouped" table (empty if that table is absent).
    ungrouped_nodes: list[Node]
    # The nodes of every table, grouped or not, in iteration order.
    all_nodes: list[Node]

    def __init__(self, ast: dict[str, Any]) -> None:
        """Build the node and group lists from the parsed `ast.toml` mapping."""
        self.groups = []
        self.ungrouped_nodes = []
        self.all_nodes = []
        for table_name, table in ast.items():
            parsed = Group(table_name, table)
            # Every node is recorded in `all_nodes`, whichever table it is in.
            self.all_nodes.extend(parsed.nodes)
            if table_name != "ungrouped":
                self.groups.append(parsed)
                continue
            # The "ungrouped" table is kept out of `groups`; only its nodes
            # are retained.
            self.ungrouped_nodes = parsed.nodes
@dataclass
class Group:
name: str
@ -89,7 +113,7 @@ def write_preamble(out: list[str]) -> None:
# Owned enum
def write_owned_enum(out: list[str], groups: list[Group]) -> None:
def write_owned_enum(out: list[str], ast: Ast) -> None:
"""
Create an enum for each group that contains an owned copy of a syntax node.
@ -112,10 +136,7 @@ def write_owned_enum(out: list[str], groups: list[Group]) -> None:
`is_type_var` method will be named `is_type_var_type_param`.
"""
for group in groups:
if group.name == "ungrouped":
continue
for group in ast.groups:
out.append("")
if group.rustdoc is not None:
out.append(group.rustdoc)
@ -150,19 +171,16 @@ def write_owned_enum(out: list[str], groups: list[Group]) -> None:
}
""")
for group in groups:
for node in group.nodes:
out.append(f"""
for node in ast.all_nodes:
out.append(f"""
impl ruff_text_size::Ranged for {node.ty} {{
fn range(&self) -> ruff_text_size::TextRange {{
self.range
}}
}}
""")
""")
for group in groups:
if group.name == "ungrouped":
continue
for group in ast.groups:
out.append(f"""
impl {group.owned_enum_ty} {{
#[allow(unused)]
@ -187,7 +205,7 @@ def write_owned_enum(out: list[str], groups: list[Group]) -> None:
# Ref enum
def write_ref_enum(out: list[str], groups: list[Group]) -> None:
def write_ref_enum(out: list[str], ast: Ast) -> None:
"""
Create an enum for each group that contains a reference to a syntax node.
@ -211,10 +229,7 @@ def write_ref_enum(out: list[str], groups: list[Group]) -> None:
method will be named `is_type_var_type_param`.
"""
for group in groups:
if group.name == "ungrouped":
continue
for group in ast.groups:
out.append("")
if group.rustdoc is not None:
out.append(group.rustdoc)
@ -269,7 +284,7 @@ def write_ref_enum(out: list[str], groups: list[Group]) -> None:
# AnyNodeRef
def write_anynoderef(out: list[str], groups: list[Group]) -> None:
def write_anynoderef(out: list[str], ast: Ast) -> None:
"""
Create the AnyNodeRef type.
@ -295,62 +310,59 @@ def write_anynoderef(out: list[str], groups: list[Group]) -> None:
#[derive(Copy, Clone, Debug, is_macro::Is, PartialEq)]
pub enum AnyNodeRef<'a> {
""")
for group in groups:
for node in group.nodes:
out.append(f"""{node.name}(&'a {node.ty}),""")
for node in ast.all_nodes:
out.append(f"""{node.name}(&'a {node.ty}),""")
out.append("""
}
""")
for group in groups:
if group.name != "ungrouped":
out.append(f"""
for group in ast.groups:
out.append(f"""
impl<'a> From<&'a {group.owned_enum_ty}> for AnyNodeRef<'a> {{
fn from(node: &'a {group.owned_enum_ty}) -> AnyNodeRef<'a> {{
match node {{
""")
for node in group.nodes:
out.append(
f"{group.owned_enum_ty}::{node.variant}(node) => AnyNodeRef::{node.name}(node),"
)
out.append("""
""")
for node in group.nodes:
out.append(
f"{group.owned_enum_ty}::{node.variant}(node) => AnyNodeRef::{node.name}(node),"
)
out.append("""
}
}
}
""")
""")
out.append(f"""
out.append(f"""
impl<'a> From<{group.ref_enum_ty}<'a>> for AnyNodeRef<'a> {{
fn from(node: {group.ref_enum_ty}<'a>) -> AnyNodeRef<'a> {{
match node {{
""")
for node in group.nodes:
out.append(
f"{group.ref_enum_ty}::{node.variant}(node) => AnyNodeRef::{node.name}(node),"
)
out.append("""
""")
for node in group.nodes:
out.append(
f"{group.ref_enum_ty}::{node.variant}(node) => AnyNodeRef::{node.name}(node),"
)
out.append("""
}
}
}
""")
""")
for node in group.nodes:
out.append(f"""
for node in ast.all_nodes:
out.append(f"""
impl<'a> From<&'a {node.ty}> for AnyNodeRef<'a> {{
fn from(node: &'a {node.ty}) -> AnyNodeRef<'a> {{
AnyNodeRef::{node.name}(node)
}}
}}
""")
""")
out.append("""
impl ruff_text_size::Ranged for AnyNodeRef<'_> {
fn range(&self) -> ruff_text_size::TextRange {
match self {
""")
for group in groups:
for node in group.nodes:
out.append(f"""AnyNodeRef::{node.name}(node) => node.range(),""")
for node in ast.all_nodes:
out.append(f"""AnyNodeRef::{node.name}(node) => node.range(),""")
out.append("""
}
}
@ -362,11 +374,10 @@ def write_anynoderef(out: list[str], groups: list[Group]) -> None:
pub fn as_ptr(&self) -> std::ptr::NonNull<()> {
match self {
""")
for group in groups:
for node in group.nodes:
out.append(
f"AnyNodeRef::{node.name}(node) => std::ptr::NonNull::from(*node).cast(),"
)
for node in ast.all_nodes:
out.append(
f"AnyNodeRef::{node.name}(node) => std::ptr::NonNull::from(*node).cast(),"
)
out.append("""
}
}
@ -382,20 +393,17 @@ def write_anynoderef(out: list[str], groups: list[Group]) -> None:
{
match self {
""")
for group in groups:
for node in group.nodes:
out.append(
f"AnyNodeRef::{node.name}(node) => node.visit_source_order(visitor),"
)
for node in ast.all_nodes:
out.append(
f"AnyNodeRef::{node.name}(node) => node.visit_source_order(visitor),"
)
out.append("""
}
}
}
""")
for group in groups:
if group.name == "ungrouped":
continue
for group in ast.groups:
out.append(f"""
impl AnyNodeRef<'_> {{
pub const fn is_{group.anynode_is_label}(self) -> bool {{
@ -416,7 +424,7 @@ def write_anynoderef(out: list[str], groups: list[Group]) -> None:
# NodeKind
def write_nodekind(out: list[str], groups: list[Group]) -> None:
def write_nodekind(out: list[str], ast: Ast) -> None:
"""
Create the NodeKind type.
@ -437,9 +445,8 @@ def write_nodekind(out: list[str], groups: list[Group]) -> None:
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum NodeKind {
""")
for group in groups:
for node in group.nodes:
out.append(f"""{node.name},""")
for node in ast.all_nodes:
out.append(f"""{node.name},""")
out.append("""
}
""")
@ -449,9 +456,8 @@ def write_nodekind(out: list[str], groups: list[Group]) -> None:
pub const fn kind(self) -> NodeKind {
match self {
""")
for group in groups:
for node in group.nodes:
out.append(f"""AnyNodeRef::{node.name}(_) => NodeKind::{node.name},""")
for node in ast.all_nodes:
out.append(f"""AnyNodeRef::{node.name}(_) => NodeKind::{node.name},""")
out.append("""
}
}
@ -463,13 +469,13 @@ def write_nodekind(out: list[str], groups: list[Group]) -> None:
# Format and write output
def generate(groups: list[Group]) -> list[str]:
def generate(ast: Ast) -> list[str]:
out = []
write_preamble(out)
write_owned_enum(out, groups)
write_ref_enum(out, groups)
write_anynoderef(out, groups)
write_nodekind(out, groups)
write_owned_enum(out, ast)
write_ref_enum(out, ast)
write_anynoderef(out, ast)
write_nodekind(out, ast)
return out
@ -486,8 +492,8 @@ def main() -> None:
root = Path(
check_output(["git", "rev-parse", "--show-toplevel"], text=True).strip()
)
groups = load_ast(root)
out = generate(groups)
ast = load_ast(root)
out = generate(ast)
write_output(root, out)