#!/usr/bin/python # /// script # requires-python = ">=3.11" # dependencies = [] # /// from __future__ import annotations import re from dataclasses import dataclass from pathlib import Path from subprocess import check_output from typing import Any import tomllib # Types that require `crate::`. We can slowly remove these types as we move them to generate scripts. types_requiring_crate_prefix = { "IpyEscapeKind", "ExprContext", "Identifier", "Number", "BytesLiteralValue", "StringLiteralValue", "FStringValue", "TStringValue", "Arguments", "CmpOp", "Comprehension", "DictItem", "UnaryOp", "BoolOp", "Operator", "Decorator", "TypeParams", "Parameters", "ElifElseClause", "WithItem", "MatchCase", "Alias", } @dataclass class VisitorInfo: name: str accepts_sequence: bool = False # Map of AST node types to their corresponding visitor information. # Only visitors that are different from the default `visit_*` method are included. # These visitors either have a different name or accept a sequence of items. 
type_to_visitor_function: dict[str, VisitorInfo] = { "TypeParams": VisitorInfo("visit_type_params", True), "Parameters": VisitorInfo("visit_parameters", True), "Stmt": VisitorInfo("visit_body", True), "Arguments": VisitorInfo("visit_arguments", True), } def rustfmt(code: str) -> str: return check_output(["rustfmt", "--emit=stdout"], input=code, text=True) def to_snake_case(node: str) -> str: """Converts CamelCase to snake_case""" return re.sub("([A-Z])", r"_\1", node).lower().lstrip("_") def write_rustdoc(out: list[str], doc: str) -> None: for line in doc.split("\n"): out.append(f"/// {line}") # ------------------------------------------------------------------------------ # Read AST description def load_ast(root: Path) -> Ast: ast_path = root.joinpath("crates", "ruff_python_ast", "ast.toml") with ast_path.open("rb") as ast_file: ast = tomllib.load(ast_file) return Ast(ast) # ------------------------------------------------------------------------------ # Preprocess @dataclass class Ast: """ The parsed representation of the `ast.toml` file. Defines all of the Python AST syntax nodes, and which groups (`Stmt`, `Expr`, etc.) they belong to. 
""" groups: list[Group] ungrouped_nodes: list[Node] all_nodes: list[Node] def __init__(self, ast: dict[str, Any]) -> None: self.groups = [] self.ungrouped_nodes = [] self.all_nodes = [] for group_name, group in ast.items(): group = Group(group_name, group) self.all_nodes.extend(group.nodes) if group_name == "ungrouped": self.ungrouped_nodes = group.nodes else: self.groups.append(group) @dataclass class Group: name: str nodes: list[Node] owned_enum_ty: str add_suffix_to_is_methods: bool anynode_is_label: str doc: str | None def __init__(self, group_name: str, group: dict[str, Any]) -> None: self.name = group_name self.owned_enum_ty = group_name self.ref_enum_ty = group_name + "Ref" self.add_suffix_to_is_methods = group.get("add_suffix_to_is_methods", False) self.anynode_is_label = group.get("anynode_is_label", to_snake_case(group_name)) self.doc = group.get("doc") self.nodes = [ Node(self, node_name, node) for node_name, node in group["nodes"].items() ] @dataclass class Node: name: str variant: str ty: str doc: str | None fields: list[Field] | None derives: list[str] custom_source_order: bool source_order: list[str] | None def __init__(self, group: Group, node_name: str, node: dict[str, Any]) -> None: self.name = node_name self.variant = node.get("variant", node_name.removeprefix(group.name)) self.ty = f"crate::{node_name}" self.fields = None fields = node.get("fields") if fields is not None: self.fields = [Field(f) for f in fields] self.custom_source_order = node.get("custom_source_order", False) self.derives = node.get("derives", []) self.doc = node.get("doc") self.source_order = node.get("source_order") def fields_in_source_order(self) -> list[Field]: if self.fields is None: return [] if self.source_order is None: return list(filter(lambda x: not x.skip_source_order(), self.fields)) fields = [] for field_name in self.source_order: field = None for field in self.fields: if field.skip_source_order(): continue if field.name == field_name: field = field break 
fields.append(field) return fields @dataclass class Field: name: str ty: str _skip_visit: bool is_annotation: bool parsed_ty: FieldType def __init__(self, field: dict[str, Any]) -> None: self.name = field["name"] self.ty = field["type"] self.parsed_ty = FieldType(self.ty) self._skip_visit = field.get("skip_visit", False) self.is_annotation = field.get("is_annotation", False) def skip_source_order(self) -> bool: return self._skip_visit or self.parsed_ty.inner in [ "str", "ExprContext", "Name", "u32", "bool", "Number", "IpyEscapeKind", ] # Extracts the type argument from the given rust type with AST field type syntax. # Box -> str # Box -> Expr # If the type does not have a type argument, it will return the string. # Does not support nested types def extract_type_argument(rust_type_str: str) -> str: rust_type_str = rust_type_str.replace("*", "") rust_type_str = rust_type_str.replace("?", "") rust_type_str = rust_type_str.replace("&", "") open_bracket_index = rust_type_str.find("<") if open_bracket_index == -1: return rust_type_str close_bracket_index = rust_type_str.rfind(">") if close_bracket_index == -1 or close_bracket_index <= open_bracket_index: raise ValueError(f"Brackets are not balanced for type {rust_type_str}") inner_type = rust_type_str[open_bracket_index + 1 : close_bracket_index].strip() inner_type = inner_type.replace("crate::", "") return inner_type @dataclass class FieldType: rule: str name: str inner: str seq: bool = False optional: bool = False slice_: bool = False def __init__(self, rule: str) -> None: self.rule = rule self.name = "" self.inner = extract_type_argument(rule) # The following cases are the limitations of this parser(and not used in the ast.toml): # * Rules that involve declaring a sequence with optional items e.g. 
Vec> last_pos = len(rule) - 1 for i, ch in enumerate(rule): if ch == "?": if i == last_pos: self.optional = True else: raise ValueError(f"`?` must be at the end: {rule}") elif ch == "*": if self.slice_: # The * after & is a slice continue if i == last_pos: self.seq = True else: raise ValueError(f"`*` must be at the end: {rule}") elif ch == "&": if i == 0 and rule.endswith("*"): self.slice_ = True else: raise ValueError( f"`&` must be at the start and end with `*`: {rule}" ) else: self.name += ch if self.optional and (self.seq or self.slice_): raise ValueError(f"optional field cannot be sequence or slice: {rule}") # ------------------------------------------------------------------------------ # Preamble def write_preamble(out: list[str]) -> None: out.append(""" // This is a generated file. Don't modify it by hand! // Run `crates/ruff_python_ast/generate.py` to re-generate the file. use crate::name::Name; use crate::visitor::source_order::SourceOrderVisitor; """) # ------------------------------------------------------------------------------ # Owned enum def write_owned_enum(out: list[str], ast: Ast) -> None: """ Create an enum for each group that contains an owned copy of a syntax node. ```rust pub enum TypeParam { TypeVar(TypeParamTypeVar), TypeVarTuple(TypeParamTypeVarTuple), ... } ``` Also creates: - `impl Ranged for TypeParam` - `impl HasNodeIndex for TypeParam` - `TypeParam::visit_source_order` - `impl From for TypeParam` - `impl Ranged for TypeParamTypeVar` - `impl HasNodeIndex for TypeParamTypeVar` - `fn TypeParam::is_type_var() -> bool` If the `add_suffix_to_is_methods` group option is true, then the `is_type_var` method will be named `is_type_var_type_param`. 
""" for group in ast.groups: out.append("") if group.doc is not None: write_rustdoc(out, group.doc) out.append("#[derive(Clone, Debug, PartialEq)]") out.append('#[cfg_attr(feature = "get-size", derive(get_size2::GetSize))]') out.append(f"pub enum {group.owned_enum_ty} {{") for node in group.nodes: out.append(f"{node.variant}({node.ty}),") out.append("}") for node in group.nodes: out.append(f""" impl From<{node.ty}> for {group.owned_enum_ty} {{ fn from(node: {node.ty}) -> Self {{ Self::{node.variant}(node) }} }} """) out.append(f""" impl ruff_text_size::Ranged for {group.owned_enum_ty} {{ fn range(&self) -> ruff_text_size::TextRange {{ match self {{ """) for node in group.nodes: out.append(f"Self::{node.variant}(node) => node.range(),") out.append(""" } } } """) out.append(f""" impl crate::HasNodeIndex for {group.owned_enum_ty} {{ fn node_index(&self) -> &crate::AtomicNodeIndex {{ match self {{ """) for node in group.nodes: out.append(f"Self::{node.variant}(node) => node.node_index(),") out.append(""" } } } """) out.append( "#[allow(dead_code, clippy::match_wildcard_for_single_variants)]" ) # Not all is_methods are used out.append(f"impl {group.name} {{") for node in group.nodes: is_name = to_snake_case(node.variant) variant_name = node.variant match_arm = f"Self::{variant_name}" if group.add_suffix_to_is_methods: is_name = to_snake_case(node.variant + group.name) if len(group.nodes) > 1: out.append(f""" #[inline] pub const fn is_{is_name}(&self) -> bool {{ matches!(self, {match_arm}(_)) }} #[inline] pub fn {is_name}(self) -> Option<{node.ty}> {{ match self {{ {match_arm}(val) => Some(val), _ => None, }} }} #[inline] pub fn expect_{is_name}(self) -> {node.ty} {{ match self {{ {match_arm}(val) => val, _ => panic!("called expect on {{self:?}}"), }} }} #[inline] pub fn as_{is_name}_mut(&mut self) -> Option<&mut {node.ty}> {{ match self {{ {match_arm}(val) => Some(val), _ => None, }} }} #[inline] pub fn as_{is_name}(&self) -> Option<&{node.ty}> {{ match self {{ 
{match_arm}(val) => Some(val), _ => None, }} }} """) elif len(group.nodes) == 1: out.append(f""" #[inline] pub const fn is_{is_name}(&self) -> bool {{ matches!(self, {match_arm}(_)) }} #[inline] pub fn {is_name}(self) -> Option<{node.ty}> {{ match self {{ {match_arm}(val) => Some(val), }} }} #[inline] pub fn expect_{is_name}(self) -> {node.ty} {{ match self {{ {match_arm}(val) => val, }} }} #[inline] pub fn as_{is_name}_mut(&mut self) -> Option<&mut {node.ty}> {{ match self {{ {match_arm}(val) => Some(val), }} }} #[inline] pub fn as_{is_name}(&self) -> Option<&{node.ty}> {{ match self {{ {match_arm}(val) => Some(val), }} }} """) out.append("}") for node in ast.all_nodes: out.append(f""" impl ruff_text_size::Ranged for {node.ty} {{ fn range(&self) -> ruff_text_size::TextRange {{ self.range }} }} """) for node in ast.all_nodes: out.append(f""" impl crate::HasNodeIndex for {node.ty} {{ fn node_index(&self) -> &crate::AtomicNodeIndex {{ &self.node_index }} }} """) for group in ast.groups: out.append(f""" impl {group.owned_enum_ty} {{ #[allow(unused)] pub(crate) fn visit_source_order<'a, V>(&'a self, visitor: &mut V) where V: crate::visitor::source_order::SourceOrderVisitor<'a> + ?Sized, {{ match self {{ """) for node in group.nodes: out.append( f"{group.owned_enum_ty}::{node.variant}(node) => node.visit_source_order(visitor)," ) out.append(""" } } } """) # ------------------------------------------------------------------------------ # Ref enum def write_ref_enum(out: list[str], ast: Ast) -> None: """ Create an enum for each group that contains a reference to a syntax node. ```rust pub enum TypeParamRef<'a> { TypeVar(&'a TypeParamTypeVar), TypeVarTuple(&'a TypeParamTypeVarTuple), ... 
} ``` Also creates: - `impl<'a> From<&'a TypeParam> for TypeParamRef<'a>` - `impl<'a> From<&'a TypeParamTypeVar> for TypeParamRef<'a>` - `impl Ranged for TypeParamRef<'_>` - `impl HasNodeIndex for TypeParamRef<'_>` - `fn TypeParamRef::is_type_var() -> bool` The name of each variant can be customized via the `variant` node option. If the `add_suffix_to_is_methods` group option is true, then the `is_type_var` method will be named `is_type_var_type_param`. """ for group in ast.groups: out.append("") if group.doc is not None: write_rustdoc(out, group.doc) out.append("""#[derive(Clone, Copy, Debug, PartialEq, is_macro::Is)]""") out.append('#[cfg_attr(feature = "get-size", derive(get_size2::GetSize))]') out.append(f"""pub enum {group.ref_enum_ty}<'a> {{""") for node in group.nodes: if group.add_suffix_to_is_methods: is_name = to_snake_case(node.variant + group.name) out.append(f'#[is(name = "{is_name}")]') out.append(f"""{node.variant}(&'a {node.ty}),""") out.append("}") out.append(f""" impl<'a> From<&'a {group.owned_enum_ty}> for {group.ref_enum_ty}<'a> {{ fn from(node: &'a {group.owned_enum_ty}) -> Self {{ match node {{ """) for node in group.nodes: out.append( f"{group.owned_enum_ty}::{node.variant}(node) => {group.ref_enum_ty}::{node.variant}(node)," ) out.append(""" } } } """) for node in group.nodes: out.append(f""" impl<'a> From<&'a {node.ty}> for {group.ref_enum_ty}<'a> {{ fn from(node: &'a {node.ty}) -> Self {{ Self::{node.variant}(node) }} }} """) out.append(f""" impl ruff_text_size::Ranged for {group.ref_enum_ty}<'_> {{ fn range(&self) -> ruff_text_size::TextRange {{ match self {{ """) for node in group.nodes: out.append(f"Self::{node.variant}(node) => node.range(),") out.append(""" } } } """) out.append(f""" impl crate::HasNodeIndex for {group.ref_enum_ty}<'_> {{ fn node_index(&self) -> &crate::AtomicNodeIndex {{ match self {{ """) for node in group.nodes: out.append(f"Self::{node.variant}(node) => node.node_index(),") out.append(""" } } } """) # 
# ------------------------------------------------------------------------------
# AnyNodeRef


def write_anynoderef(out: list[str], ast: Ast) -> None:
    """
    Create the AnyNodeRef type.

    ```rust
    pub enum AnyNodeRef<'a> {
        ...
        TypeParamTypeVar(&'a TypeParamTypeVar),
        TypeParamTypeVarTuple(&'a TypeParamTypeVarTuple),
        ...
    }
    ```

    Also creates:
    - `impl<'a> From<&'a TypeParam> for AnyNodeRef<'a>`
    - `impl<'a> From<TypeParamRef<'a>> for AnyNodeRef<'a>`
    - `impl<'a> From<&'a TypeParamTypeVarTuple> for AnyNodeRef<'a>`
    - `impl Ranged for AnyNodeRef<'_>`
    - `impl HasNodeIndex for AnyNodeRef<'_>`
    - `fn AnyNodeRef::as_ptr(&self) -> std::ptr::NonNull<()>`
    - `fn AnyNodeRef::visit_source_order(self, visitor &mut impl SourceOrderVisitor)`
    """

    # The `///` line must end with a newline so that it does not comment out
    # the attributes and the enum header.
    out.append("""
    /// A flattened enumeration of all AST nodes.
    #[derive(Copy, Clone, Debug, is_macro::Is, PartialEq)]
    #[cfg_attr(feature = "get-size", derive(get_size2::GetSize))]
    pub enum AnyNodeRef<'a> {
    """)
    for node in ast.all_nodes:
        out.append(f"""{node.name}(&'a {node.ty}),""")
    out.append("""
    }
    """)

    for group in ast.groups:
        out.append(f"""
        impl<'a> From<&'a {group.owned_enum_ty}> for AnyNodeRef<'a> {{
            fn from(node: &'a {group.owned_enum_ty}) -> AnyNodeRef<'a> {{
                match node {{
        """)
        for node in group.nodes:
            out.append(
                f"{group.owned_enum_ty}::{node.variant}(node) => AnyNodeRef::{node.name}(node),"
            )
        out.append("""
                }
            }
        }
        """)

        out.append(f"""
        impl<'a> From<{group.ref_enum_ty}<'a>> for AnyNodeRef<'a> {{
            fn from(node: {group.ref_enum_ty}<'a>) -> AnyNodeRef<'a> {{
                match node {{
        """)
        for node in group.nodes:
            out.append(
                f"{group.ref_enum_ty}::{node.variant}(node) => AnyNodeRef::{node.name}(node),"
            )
        out.append("""
                }
            }
        }
        """)

        # `as_*` methods to convert from `AnyNodeRef` to e.g. `ExprRef`
        out.append(f"""
        impl<'a> AnyNodeRef<'a> {{
            pub fn as_{to_snake_case(group.ref_enum_ty)}(self) -> Option<{group.ref_enum_ty}<'a>> {{
                match self {{
        """)
        for node in group.nodes:
            out.append(
                f"Self::{node.name}(node) => Some({group.ref_enum_ty}::{node.variant}(node)),"
            )
        out.append("""
                    _ => None,
                }
            }
        }
        """)

    for node in ast.all_nodes:
        out.append(f"""
        impl<'a> From<&'a {node.ty}> for AnyNodeRef<'a> {{
            fn from(node: &'a {node.ty}) -> AnyNodeRef<'a> {{
                AnyNodeRef::{node.name}(node)
            }}
        }}
        """)

    out.append("""
    impl ruff_text_size::Ranged for AnyNodeRef<'_> {
        fn range(&self) -> ruff_text_size::TextRange {
            match self {
    """)
    for node in ast.all_nodes:
        out.append(f"""AnyNodeRef::{node.name}(node) => node.range(),""")
    out.append("""
            }
        }
    }
    """)

    out.append("""
    impl crate::HasNodeIndex for AnyNodeRef<'_> {
        fn node_index(&self) -> &crate::AtomicNodeIndex {
            match self {
    """)
    for node in ast.all_nodes:
        out.append(f"""AnyNodeRef::{node.name}(node) => node.node_index(),""")
    out.append("""
            }
        }
    }
    """)

    out.append("""
    impl AnyNodeRef<'_> {
        pub fn as_ptr(&self) -> std::ptr::NonNull<()> {
            match self {
    """)
    for node in ast.all_nodes:
        out.append(
            f"AnyNodeRef::{node.name}(node) => std::ptr::NonNull::from(*node).cast(),"
        )
    out.append("""
            }
        }
    }
    """)

    out.append("""
    impl<'a> AnyNodeRef<'a> {
        pub fn visit_source_order<'b, V>(self, visitor: &mut V)
        where
            V: crate::visitor::source_order::SourceOrderVisitor<'b> + ?Sized,
            'a: 'b,
        {
            match self {
    """)
    for node in ast.all_nodes:
        out.append(
            f"AnyNodeRef::{node.name}(node) => node.visit_source_order(visitor),"
        )
    out.append("""
            }
        }
    }
    """)

    for group in ast.groups:
        out.append(f"""
        impl AnyNodeRef<'_> {{
            pub const fn is_{group.anynode_is_label}(self) -> bool {{
                matches!(self,
        """)
        for i, node in enumerate(group.nodes):
            if i > 0:
                out.append("|")
            out.append(f"""AnyNodeRef::{node.name}(_)""")
        out.append("""
                )
            }
        }
        """)


# ------------------------------------------------------------------------------
# AnyRootNodeRef


def write_root_anynoderef(out: list[str], ast: Ast) -> None:
    """
    Create the AnyRootNodeRef type.

    ```rust
    pub enum AnyRootNodeRef<'a> {
        ...
        TypeParam(&'a TypeParam),
        ...
    }
    ```

    Also creates:
    - `impl<'a> From<&'a TypeParam> for AnyRootNodeRef<'a>`
    - `impl<'a> TryFrom<AnyRootNodeRef<'a>> for &'a TypeParam`
    - `impl<'a> TryFrom<AnyRootNodeRef<'a>> for &'a TypeParamTypeVarTuple`
    - `impl Ranged for AnyRootNodeRef<'_>`
    - `impl HasNodeIndex for AnyRootNodeRef<'_>`
    - `fn AnyRootNodeRef::visit_source_order(self, visitor &mut impl SourceOrderVisitor)`
    """

    out.append("""
    /// An enumeration of all AST nodes.
    ///
    /// Unlike `AnyNodeRef`, this type does not flatten nested enums, so its variants only
    /// consist of the "root" AST node types. This is useful as it exposes references to the
    /// original enums, not just references to their inner values.
    ///
    /// For example, `AnyRootNodeRef::Mod` contains a reference to the `Mod` enum, while
    /// `AnyNodeRef` has top-level `AnyNodeRef::ModModule` and `AnyNodeRef::ModExpression`
    /// variants.
    #[derive(Copy, Clone, Debug, PartialEq)]
    #[cfg_attr(feature = "get-size", derive(get_size2::GetSize))]
    pub enum AnyRootNodeRef<'a> {
    """)
    for group in ast.groups:
        out.append(f"""{group.name}(&'a {group.owned_enum_ty}),""")
    for node in ast.ungrouped_nodes:
        out.append(f"""{node.name}(&'a {node.ty}),""")
    out.append("""
    }
    """)

    for group in ast.groups:
        out.append(f"""
        impl<'a> From<&'a {group.owned_enum_ty}> for AnyRootNodeRef<'a> {{
            fn from(node: &'a {group.owned_enum_ty}) -> AnyRootNodeRef<'a> {{
                AnyRootNodeRef::{group.name}(node)
            }}
        }}
        """)

        # The conversion source is `AnyRootNodeRef<'a>`, matching the
        # parameter type of `try_from` below.
        out.append(f"""
        impl<'a> TryFrom<AnyRootNodeRef<'a>> for &'a {group.owned_enum_ty} {{
            type Error = ();
            fn try_from(node: AnyRootNodeRef<'a>) -> Result<&'a {group.owned_enum_ty}, ()> {{
                match node {{
                    AnyRootNodeRef::{group.name}(node) => Ok(node),
                    _ => Err(())
                }}
            }}
        }}
        """)

        for node in group.nodes:
            out.append(f"""
            impl<'a> TryFrom<AnyRootNodeRef<'a>> for &'a {node.ty} {{
                type Error = ();
                fn try_from(node: AnyRootNodeRef<'a>) -> Result<&'a {node.ty}, ()> {{
                    match node {{
                        AnyRootNodeRef::{group.name}({group.owned_enum_ty}::{node.variant}(node)) => Ok(node),
                        _ => Err(())
                    }}
                }}
            }}
            """)

    for node in ast.ungrouped_nodes:
        out.append(f"""
        impl<'a> From<&'a {node.ty}> for AnyRootNodeRef<'a> {{
            fn from(node: &'a {node.ty}) -> AnyRootNodeRef<'a> {{
                AnyRootNodeRef::{node.name}(node)
            }}
        }}
        """)

        out.append(f"""
        impl<'a> TryFrom<AnyRootNodeRef<'a>> for &'a {node.ty} {{
            type Error = ();
            fn try_from(node: AnyRootNodeRef<'a>) -> Result<&'a {node.ty}, ()> {{
                match node {{
                    AnyRootNodeRef::{node.name}(node) => Ok(node),
                    _ => Err(())
                }}
            }}
        }}
        """)

    out.append("""
    impl ruff_text_size::Ranged for AnyRootNodeRef<'_> {
        fn range(&self) -> ruff_text_size::TextRange {
            match self {
    """)
    for group in ast.groups:
        out.append(f"""AnyRootNodeRef::{group.name}(node) => node.range(),""")
    for node in ast.ungrouped_nodes:
        out.append(f"""AnyRootNodeRef::{node.name}(node) => node.range(),""")
    out.append("""
            }
        }
    }
    """)

    out.append("""
    impl crate::HasNodeIndex for AnyRootNodeRef<'_> {
        fn node_index(&self) -> &crate::AtomicNodeIndex {
            match self {
    """)
    for group in ast.groups:
        out.append(f"""AnyRootNodeRef::{group.name}(node) => node.node_index(),""")
    for node in ast.ungrouped_nodes:
        out.append(f"""AnyRootNodeRef::{node.name}(node) => node.node_index(),""")
    out.append("""
            }
        }
    }
    """)

    out.append("""
    impl<'a> AnyRootNodeRef<'a> {
        pub fn visit_source_order<'b, V>(self, visitor: &mut V)
        where
            V: crate::visitor::source_order::SourceOrderVisitor<'b> + ?Sized,
            'a: 'b,
        {
            match self {
    """)
    for group in ast.groups:
        out.append(
            f"""AnyRootNodeRef::{group.name}(node) => node.visit_source_order(visitor),"""
        )
    for node in ast.ungrouped_nodes:
        out.append(
            f"""AnyRootNodeRef::{node.name}(node) => node.visit_source_order(visitor),"""
        )
    out.append("""
            }
        }
    }
    """)


# ------------------------------------------------------------------------------
# NodeKind


def write_nodekind(out: list[str], ast: Ast) -> None:
    """
    Create the NodeKind type.

    ```rust
    pub enum NodeKind {
        ...
        TypeParamTypeVar,
        TypeParamTypeVarTuple,
        ...
    }
    ```

    Also creates:
    - `fn AnyNodeRef::kind(self) -> NodeKind`
    """

    out.append("""
    #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
    pub enum NodeKind {
    """)
    for node in ast.all_nodes:
        out.append(f"""{node.name},""")
    out.append("""
    }
    """)

    out.append("""
    impl AnyNodeRef<'_> {
        pub const fn kind(self) -> NodeKind {
            match self {
    """)
    for node in ast.all_nodes:
        out.append(f"""AnyNodeRef::{node.name}(_) => NodeKind::{node.name},""")
    out.append("""
            }
        }
    }
    """)


# ------------------------------------------------------------------------------
# Node structs


def write_node(out: list[str], ast: Ast) -> None:
    """
    Create the struct definition of every grouped node that declares `fields`.

    Each struct starts with the `node_index` and `range` fields, followed by
    the declared fields with their Rust types derived from the parsed field
    type (`Box`, `Vec`, `Option`, slices and the `crate::` prefix).
    """
    group_names = {group.name for group in ast.groups}
    for group in ast.groups:
        for node in group.nodes:
            if node.fields is None:
                continue
            if node.doc is not None:
                write_rustdoc(out, node.doc)
            out.append(
                "#[derive(Clone, Debug, PartialEq"
                + "".join(f", {derive}" for derive in node.derives)
                + ")]"
            )
            out.append('#[cfg_attr(feature = "get-size", derive(get_size2::GetSize))]')
            out.append(f"pub struct {node.name} {{")
            out.append("pub node_index: crate::AtomicNodeIndex,")
            out.append("pub range: ruff_text_size::TextRange,")
            for field in node.fields:
                ty = field.parsed_ty

                rust_ty = f"{ty.name}"
                if ty.name in types_requiring_crate_prefix:
                    rust_ty = f"crate::{rust_ty}"
                if ty.slice_:
                    rust_ty = f"[{rust_ty}]"
                # Group enums and slices are boxed unless they are the items
                # of a sequence.
                if (ty.name in group_names or ty.slice_) and ty.seq is False:
                    rust_ty = f"Box<{rust_ty}>"

                if ty.seq:
                    rust_ty = f"Vec<{rust_ty}>"
                elif ty.optional:
                    rust_ty = f"Option<{rust_ty}>"

                out.append(f"pub {field.name}: {rust_ty},")
            out.append("}")
            out.append("")


# ------------------------------------------------------------------------------
# Source order visitor


def write_source_order(out: list[str], ast: Ast) -> None:
    """
    Create `visit_source_order` for every grouped node that declares `fields`
    and does not set `custom_source_order`.

    The generated method destructures the node, ignoring skipped fields, and
    calls the matching visitor method for each field in source order.
    """
    for group in ast.groups:
        for node in group.nodes:
            if node.fields is None or node.custom_source_order:
                continue
            name = node.name
            fields_list = ""
            body = ""

            for field in node.fields:
                if field.skip_source_order():
                    fields_list += f"{field.name}: _,\n"
                else:
                    fields_list += f"{field.name},\n"
            fields_list += "range: _,\n"
            fields_list += "node_index: _,\n"

            ordered_fields = node.fields_in_source_order()
            for field in ordered_fields:
                inner = field.parsed_ty.inner
                info = type_to_visitor_function.get(inner)

                if field.is_annotation:
                    visitor_name = "visit_annotation"
                elif info is not None and info.name:
                    visitor_name = info.name
                else:
                    visitor_name = f"visit_{to_snake_case(inner)}"
                visits_sequence = info is not None and info.accepts_sequence

                if field.parsed_ty.optional:
                    body += f"""
                    if let Some({field.name}) = {field.name} {{
                        visitor.{visitor_name}({field.name});
                    }}\n
                    """
                elif not visits_sequence and field.parsed_ty.seq:
                    body += f"""
                    for elm in {field.name} {{
                        visitor.{visitor_name}(elm);
                    }}
                    """
                else:
                    body += f"visitor.{visitor_name}({field.name});\n"

            # Avoid an unused-parameter warning when nothing is visited.
            visitor_arg_name = "visitor" if ordered_fields else "_"

            out.append(f"""
impl {name} {{
    pub(crate) fn visit_source_order<'a, V>(&'a self, {visitor_arg_name}: &mut V)
    where
        V: SourceOrderVisitor<'a> + ?Sized,
    {{
        let {name} {{
            {fields_list}
        }} = self;
        {body}
    }}
}}
""")


# ------------------------------------------------------------------------------
# Format and write output


def generate(ast: Ast) -> list[str]:
    """Run every writer over `ast` and return the unformatted Rust fragments."""
    out = []
    write_preamble(out)
    write_owned_enum(out, ast)
    write_ref_enum(out, ast)
    write_anynoderef(out, ast)
    write_root_anynoderef(out, ast)
    write_nodekind(out, ast)
    write_node(out, ast)
    write_source_order(out, ast)
    return out


def write_output(root: Path, out: list[str]) -> None:
    """Format the joined fragments with rustfmt and write `src/generated.rs`."""
    out_path = root.joinpath("crates", "ruff_python_ast", "src", "generated.rs")
    out_path.write_text(rustfmt("\n".join(out)))


# ------------------------------------------------------------------------------
# Main


def main() -> None:
    """Regenerate `generated.rs` for the git repository containing the cwd."""
    root = Path(
        check_output(["git", "rev-parse", "--show-toplevel"], text=True).strip()
    )
    ast = load_ast(root)
    out = generate(ast)
    write_output(root, out)


if __name__ == "__main__":
    main()