mirror of
https://github.com/astral-sh/ruff.git
synced 2025-07-09 22:25:09 +00:00

## Summary

Garbage collect ASTs once we are done checking a given file. Queries with a cross-file dependency on the AST will reparse the file on demand. This reduces ty's peak memory usage by ~20-30%.

The primary change of this PR is adding a `node_index` field to every AST node, that is assigned by the parser. `ParsedModule` can use this to create a flat index of AST nodes any time the file is parsed (or reparsed). This allows `AstNodeRef` to simply index into the current instance of the `ParsedModule`, instead of storing a pointer directly.

The indices are somewhat hackily (using an atomic integer) assigned by the `parsed_module` query instead of by the parser directly. Assigning the indices in source-order in the (recursive) parser turns out to be difficult, and collecting the nodes during semantic indexing is impossible as `SemanticIndex` does not hold onto a specific `ParsedModuleRef`, which the pointers in the flat AST are tied to. This means that we have to do an extra AST traversal to assign and collect the nodes into a flat index, but the small performance impact (~3% on cold runs) seems worth it for the memory savings.

Part of https://github.com/astral-sh/ty/issues/214.
1098 lines
33 KiB
Python
1098 lines
33 KiB
Python
#!/usr/bin/python
|
|
# /// script
|
|
# requires-python = ">=3.11"
|
|
# dependencies = []
|
|
# ///
|
|
|
|
from __future__ import annotations
|
|
|
|
import re
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from subprocess import check_output
|
|
from typing import Any
|
|
|
|
import tomllib
|
|
|
|
# Field types that must be written as `crate::<Type>` in the generated struct
# definitions. Entries can be removed one by one as the corresponding types
# are moved into the generated code.
types_requiring_crate_prefix = {
    "Alias",
    "Arguments",
    "BoolOp",
    "BytesLiteralValue",
    "CmpOp",
    "Comprehension",
    "Decorator",
    "DictItem",
    "ElifElseClause",
    "ExprContext",
    "FStringValue",
    "Identifier",
    "IpyEscapeKind",
    "MatchCase",
    "Number",
    "Operator",
    "Parameters",
    "StringLiteralValue",
    "TStringValue",
    "TypeParams",
    "UnaryOp",
    "WithItem",
}
|
|
|
|
|
|
@dataclass
class VisitorInfo:
    """Describes the `SourceOrderVisitor` method used to visit a field type."""

    # Name of the visitor method that is emitted as `visitor.<name>(...)`.
    name: str
    # True when the method takes the whole sequence in one call, so the
    # generator must not wrap the call in a per-element `for` loop.
    accepts_sequence: bool = False
|
|
|
|
|
|
# Visitor overrides, keyed by AST type name.
# A type is listed here only when its visitor is not the default
# `visit_<snake_case_type>` method: either the method has a different name, or
# it accepts a whole sequence of items rather than a single item.
type_to_visitor_function: dict[str, VisitorInfo] = {
    "TypeParams": VisitorInfo("visit_type_params", accepts_sequence=True),
    "Parameters": VisitorInfo("visit_parameters", accepts_sequence=True),
    "Stmt": VisitorInfo("visit_body", accepts_sequence=True),
    "Arguments": VisitorInfo("visit_arguments", accepts_sequence=True),
}
|
|
|
|
|
|
def rustfmt(code: str) -> str:
    """Pipe Rust source text through the `rustfmt` executable and return its stdout."""
    command = ["rustfmt", "--emit=stdout"]
    formatted = check_output(command, input=code, text=True)
    return formatted
|
|
|
|
|
|
def to_snake_case(node: str) -> str:
    """Convert a CamelCase identifier to snake_case.

    An underscore is inserted before every ASCII capital letter, the result is
    lower-cased, and any leading underscores are removed.
    """
    underscored = re.sub(r"[A-Z]", lambda match: f"_{match.group()}", node)
    lowered = underscored.lower()
    return lowered.lstrip("_")
|
|
|
|
|
|
def write_rustdoc(out: list[str], doc: str) -> None:
    """Append `doc` to `out` as rustdoc, one `/// `-prefixed entry per line of text."""
    out.extend(f"/// {text}" for text in doc.split("\n"))
|
|
|
|
|
|
# ------------------------------------------------------------------------------
|
|
# Read AST description
|
|
|
|
|
|
def load_ast(root: Path) -> Ast:
    """Read `crates/ruff_python_ast/ast.toml` below `root` and wrap it in an `Ast`."""
    toml_path = root / "crates" / "ruff_python_ast" / "ast.toml"
    with toml_path.open("rb") as handle:
        description = tomllib.load(handle)
    return Ast(description)
|
|
|
|
|
|
# ------------------------------------------------------------------------------
|
|
# Preprocess
|
|
|
|
|
|
@dataclass
class Ast:
    """
    In-memory form of the `ast.toml` description: every Python AST syntax node
    together with the group (`Stmt`, `Expr`, etc.) that it is part of.
    """

    groups: list[Group]
    ungrouped_nodes: list[Node]
    all_nodes: list[Node]

    def __init__(self, ast: dict[str, Any]) -> None:
        """Build the groups from the top-level tables of the parsed TOML."""
        self.groups = []
        self.ungrouped_nodes = []
        self.all_nodes = []
        for section_name, section in ast.items():
            parsed = Group(section_name, section)
            self.all_nodes.extend(parsed.nodes)
            if section_name != "ungrouped":
                self.groups.append(parsed)
                continue
            # The "ungrouped" table is not a real group: keep only its nodes.
            self.ungrouped_nodes = parsed.nodes
|
|
|
|
|
|
@dataclass
class Group:
    """A named group of syntax nodes (one top-level table of `ast.toml`)."""

    name: str
    nodes: list[Node]
    # Name of the generated owned enum; identical to the group name.
    owned_enum_ty: str
    # Name of the generated reference enum: the group name followed by `Ref`.
    ref_enum_ty: str

    add_suffix_to_is_methods: bool
    anynode_is_label: str
    doc: str | None

    def __init__(self, group_name: str, group: dict[str, Any]) -> None:
        """
        Build a group from its table in `ast.toml`.

        `group` must contain a `nodes` table (a `KeyError` is raised otherwise).
        The `add_suffix_to_is_methods`, `anynode_is_label` and `doc` options
        are optional.
        """
        self.name = group_name
        self.owned_enum_ty = group_name
        self.ref_enum_ty = group_name + "Ref"
        self.add_suffix_to_is_methods = group.get("add_suffix_to_is_methods", False)
        # Defaults to the snake_case form of the group name.
        self.anynode_is_label = group.get("anynode_is_label", to_snake_case(group_name))
        self.doc = group.get("doc")
        self.nodes = [
            Node(self, node_name, node) for node_name, node in group["nodes"].items()
        ]
|
|
|
|
|
|
@dataclass
class Node:
    """A single syntax node described in `ast.toml`."""

    name: str
    # Variant name inside the group's enums; defaults to the node name with
    # the group name removed from the front.
    variant: str
    # Rust path of the node struct: `crate::<name>`.
    ty: str
    doc: str | None
    # None when the node declares no `fields` entry.
    fields: list[Field] | None
    derives: list[str]
    custom_source_order: bool
    # Explicit visiting order as a list of field names, or None to visit the
    # fields in declaration order.
    source_order: list[str] | None

    def __init__(self, group: Group, node_name: str, node: dict[str, Any]) -> None:
        """Build a node from its table in `ast.toml`; all options are optional."""
        self.name = node_name
        self.variant = node.get("variant", node_name.removeprefix(group.name))
        self.ty = f"crate::{node_name}"
        self.fields = None
        fields = node.get("fields")
        if fields is not None:
            self.fields = [Field(f) for f in fields]
        self.custom_source_order = node.get("custom_source_order", False)
        self.derives = node.get("derives", [])
        self.doc = node.get("doc")
        self.source_order = node.get("source_order")

    def fields_in_source_order(self) -> list[Field]:
        """
        Return the fields to visit, in the order they must be visited.

        Fields for which `skip_source_order()` is true are never returned.
        Without an explicit `source_order` the declaration order is used.

        Raises:
            ValueError: if `source_order` names something that is not a
                visitable field of this node.
        """
        if self.fields is None:
            return []

        visitable = [field for field in self.fields if not field.skip_source_order()]
        if self.source_order is None:
            return visitable

        # The first field with a given name wins, as in a linear search.
        by_name: dict[str, Field] = {}
        for field in visitable:
            by_name.setdefault(field.name, field)

        ordered = []
        for field_name in self.source_order:
            if field_name not in by_name:
                # Fail loudly instead of silently emitting a visit for an
                # unrelated field.
                raise ValueError(
                    f"`source_order` entry `{field_name}` is not a visitable "
                    f"field of `{self.name}`"
                )
            ordered.append(by_name[field_name])
        return ordered
|
|
|
|
|
|
@dataclass
class Field:
    """One field of a syntax node, as declared in the node's `fields` list."""

    name: str
    ty: str
    _skip_visit: bool
    is_annotation: bool
    parsed_ty: FieldType

    def __init__(self, field: dict[str, Any]) -> None:
        """Read the mandatory `name`/`type` keys and the optional flags."""
        self.name, self.ty = field["name"], field["type"]
        self.parsed_ty = FieldType(self.ty)
        self._skip_visit = field.get("skip_visit", False)
        self.is_annotation = field.get("is_annotation", False)

    def skip_source_order(self) -> bool:
        """Tell whether the generated source-order visitor ignores this field."""
        if self._skip_visit:
            return self._skip_visit
        # Field types that are never handed to a visitor method.
        never_visited = (
            "str",
            "ExprContext",
            "Name",
            "u32",
            "bool",
            "Number",
            "IpyEscapeKind",
        )
        return self.parsed_ty.inner in never_visited
|
|
|
|
|
|
def extract_type_argument(rust_type_str: str) -> str:
    """
    Extract the type argument from a Rust type written in AST field type syntax.

        Box<str>   -> str
        Box<Expr?> -> Expr

    The `*`, `?` and `&` markers are dropped first. A type without a type
    argument is returned as-is (minus those markers). Nested type arguments
    are not supported: everything between the first `<` and the last `>` is
    returned.

    Raises:
        ValueError: if a `<` is present without a `>` after it.
    """
    cleaned = rust_type_str.translate(str.maketrans("", "", "*?&"))

    start = cleaned.find("<")
    if start == -1:
        return cleaned

    end = cleaned.rfind(">")
    # `start` is non-negative here, so this also covers a missing `>` (-1).
    if end <= start:
        raise ValueError(f"Brackets are not balanced for type {cleaned}")

    return cleaned[start + 1 : end].strip().replace("crate::", "")
|
|
|
|
|
|
@dataclass
class FieldType:
    """A field type from `ast.toml`, split into its name and its markers."""

    # The type exactly as written in `ast.toml`.
    rule: str
    # `rule` without the `?`, `*` and `&` marker characters.
    name: str
    # Result of `extract_type_argument(rule)`.
    inner: str
    # Trailing `*` without a leading `&`.
    seq: bool = False
    # Trailing `?`.
    optional: bool = False
    # Leading `&` combined with a trailing `*`.
    slice_: bool = False

    def __init__(self, rule: str) -> None:
        """
        Parse `rule`.

        Raises:
            ValueError: if a marker is in a position that is not supported,
                or if `extract_type_argument` rejects the rule.
        """
        self.rule = rule
        self.name = ""
        self.inner = extract_type_argument(rule)

        # Known limitation of this parser (not needed by ast.toml):
        # * sequences of optional items, e.g. Vec<Option<...>>
        final_index = len(rule) - 1
        for position, char in enumerate(rule):
            at_end = position == final_index
            if char == "?":
                if not at_end:
                    raise ValueError(f"`?` must be at the end: {rule}")
                self.optional = True
            elif char == "*":
                if self.slice_:  # The * after & is a slice
                    continue
                if not at_end:
                    raise ValueError(f"`*` must be at the end: {rule}")
                self.seq = True
            elif char == "&":
                if position != 0 or not rule.endswith("*"):
                    raise ValueError(
                        f"`&` must be at the start and end with `*`: {rule}"
                    )
                self.slice_ = True
            else:
                self.name += char

        if self.optional and (self.seq or self.slice_):
            raise ValueError(f"optional field cannot be sequence or slice: {rule}")
|
|
|
|
|
|
# ------------------------------------------------------------------------------
|
|
# Preamble
|
|
|
|
|
|
def write_preamble(out: list[str]) -> None:
    """Append the "generated file" notice and the shared `use` items to `out`."""
    out.append("""
    // This is a generated file. Don't modify it by hand!
    // Run `crates/ruff_python_ast/generate.py` to re-generate the file.

    use crate::name::Name;
    use crate::visitor::source_order::SourceOrderVisitor;
    """)
|
|
|
|
|
|
# ------------------------------------------------------------------------------
|
|
# Owned enum
|
|
|
|
|
|
def write_owned_enum(out: list[str], ast: Ast) -> None:
    """
    Create an enum for each group that contains an owned copy of a syntax node.

    ```rust
    pub enum TypeParam {
        TypeVar(TypeParamTypeVar),
        TypeVarTuple(TypeParamTypeVarTuple),
        ...
    }
    ```

    Also creates:
    - `impl Ranged for TypeParam`
    - `impl HasNodeIndex for TypeParam`
    - `TypeParam::visit_source_order`
    - `impl From<TypeParamTypeVar> for TypeParam`
    - `impl Ranged for TypeParamTypeVar`
    - `impl HasNodeIndex for TypeParamTypeVar`
    - `fn TypeParam::is_type_var() -> bool`

    If the `add_suffix_to_is_methods` group option is true, then the
    `is_type_var` method will be named `is_type_var_type_param`.
    """

    for group in ast.groups:
        # The enum itself: one variant per node of the group.
        out.append("")
        if group.doc is not None:
            write_rustdoc(out, group.doc)
        out.append("#[derive(Clone, Debug, PartialEq)]")
        out.append(f"pub enum {group.owned_enum_ty} {{")
        for node in group.nodes:
            out.append(f"{node.variant}({node.ty}),")
        out.append("}")

        # Conversion from each node struct into the enum.
        for node in group.nodes:
            out.append(f"""
            impl From<{node.ty}> for {group.owned_enum_ty} {{
                fn from(node: {node.ty}) -> Self {{
                    Self::{node.variant}(node)
                }}
            }}
            """)

        # `Ranged` for the enum, delegating to the wrapped node.
        out.append(f"""
        impl ruff_text_size::Ranged for {group.owned_enum_ty} {{
            fn range(&self) -> ruff_text_size::TextRange {{
                match self {{
        """)
        for node in group.nodes:
            out.append(f"Self::{node.variant}(node) => node.range(),")
        out.append("""
                }
            }
        }
        """)

        # `HasNodeIndex` for the enum, delegating to the wrapped node.
        out.append(f"""
        impl crate::HasNodeIndex for {group.owned_enum_ty} {{
            fn node_index(&self) -> &crate::AtomicNodeIndex {{
                match self {{
        """)
        for node in group.nodes:
            out.append(f"Self::{node.variant}(node) => node.node_index(),")
        out.append("""
                }
            }
        }
        """)

        # Per-variant `is_*`, owned accessor, `expect_*`, `as_*_mut` and `as_*`.
        out.append(
            "#[allow(dead_code, clippy::match_wildcard_for_single_variants)]"
        )  # Not all is_methods are used
        out.append(f"impl {group.name} {{")
        for node in group.nodes:
            is_name = to_snake_case(node.variant)
            variant_name = node.variant
            match_arm = f"Self::{variant_name}"
            if group.add_suffix_to_is_methods:
                is_name = to_snake_case(node.variant + group.name)
            if len(group.nodes) > 1:
                # Several variants: every `match` gets a wildcard arm.
                out.append(f"""
                    #[inline]
                    pub const fn is_{is_name}(&self) -> bool {{
                        matches!(self, {match_arm}(_))
                    }}

                    #[inline]
                    pub fn {is_name}(self) -> Option<{node.ty}> {{
                        match self {{
                            {match_arm}(val) => Some(val),
                            _ => None,
                        }}
                    }}

                    #[inline]
                    pub fn expect_{is_name}(self) -> {node.ty} {{
                        match self {{
                            {match_arm}(val) => val,
                            _ => panic!("called expect on {{self:?}}"),
                        }}
                    }}

                    #[inline]
                    pub fn as_{is_name}_mut(&mut self) -> Option<&mut {node.ty}> {{
                        match self {{
                            {match_arm}(val) => Some(val),
                            _ => None,
                        }}
                    }}

                    #[inline]
                    pub fn as_{is_name}(&self) -> Option<&{node.ty}> {{
                        match self {{
                            {match_arm}(val) => Some(val),
                            _ => None,
                        }}
                    }}
                """)
            elif len(group.nodes) == 1:
                # Single variant: same methods, but without the wildcard arms.
                out.append(f"""
                    #[inline]
                    pub const fn is_{is_name}(&self) -> bool {{
                        matches!(self, {match_arm}(_))
                    }}

                    #[inline]
                    pub fn {is_name}(self) -> Option<{node.ty}> {{
                        match self {{
                            {match_arm}(val) => Some(val),
                        }}
                    }}

                    #[inline]
                    pub fn expect_{is_name}(self) -> {node.ty} {{
                        match self {{
                            {match_arm}(val) => val,
                        }}
                    }}

                    #[inline]
                    pub fn as_{is_name}_mut(&mut self) -> Option<&mut {node.ty}> {{
                        match self {{
                            {match_arm}(val) => Some(val),
                        }}
                    }}

                    #[inline]
                    pub fn as_{is_name}(&self) -> Option<&{node.ty}> {{
                        match self {{
                            {match_arm}(val) => Some(val),
                        }}
                    }}
                """)

        out.append("}")

    # `Ranged` for every node struct, grouped or not.
    for node in ast.all_nodes:
        out.append(f"""
            impl ruff_text_size::Ranged for {node.ty} {{
                fn range(&self) -> ruff_text_size::TextRange {{
                    self.range
                }}
            }}
        """)

    # `HasNodeIndex` for every node struct, grouped or not.
    for node in ast.all_nodes:
        out.append(f"""
            impl crate::HasNodeIndex for {node.ty} {{
                fn node_index(&self) -> &crate::AtomicNodeIndex {{
                    &self.node_index
                }}
            }}
        """)

    # `visit_source_order` on each enum, dispatching to the wrapped node.
    for group in ast.groups:
        out.append(f"""
            impl {group.owned_enum_ty} {{
                #[allow(unused)]
                pub(crate) fn visit_source_order<'a, V>(&'a self, visitor: &mut V)
                where
                    V: crate::visitor::source_order::SourceOrderVisitor<'a> + ?Sized,
                {{
                    match self {{
        """)
        for node in group.nodes:
            out.append(
                f"{group.owned_enum_ty}::{node.variant}(node) => node.visit_source_order(visitor),"
            )
        out.append("""
                    }
                }
            }
        """)
|
|
|
|
|
|
# ------------------------------------------------------------------------------
|
|
# Ref enum
|
|
|
|
|
|
def write_ref_enum(out: list[str], ast: Ast) -> None:
    """
    Create an enum for each group that contains a reference to a syntax node.

    ```rust
    pub enum TypeParamRef<'a> {
        TypeVar(&'a TypeParamTypeVar),
        TypeVarTuple(&'a TypeParamTypeVarTuple),
        ...
    }
    ```

    Also creates:
    - `impl<'a> From<&'a TypeParam> for TypeParamRef<'a>`
    - `impl<'a> From<&'a TypeParamTypeVar> for TypeParamRef<'a>`
    - `impl Ranged for TypeParamRef<'_>`
    - `impl HasNodeIndex for TypeParamRef<'_>`
    - `fn TypeParamRef::is_type_var() -> bool`

    The name of each variant can be customized via the `variant` node option. If
    the `add_suffix_to_is_methods` group option is true, then the `is_type_var`
    method will be named `is_type_var_type_param`.
    """

    for group in ast.groups:
        # The reference enum; the `is_*` methods come from `is_macro::Is`.
        out.append("")
        if group.doc is not None:
            write_rustdoc(out, group.doc)
        out.append("""#[derive(Clone, Copy, Debug, PartialEq, is_macro::Is)]""")
        out.append(f"""pub enum {group.ref_enum_ty}<'a> {{""")
        for node in group.nodes:
            if group.add_suffix_to_is_methods:
                # Rename the derived `is_*` method so it carries the group suffix.
                is_name = to_snake_case(node.variant + group.name)
                out.append(f'#[is(name = "{is_name}")]')
            out.append(f"""{node.variant}(&'a {node.ty}),""")
        out.append("}")

        # Conversion from a reference to the owned enum.
        out.append(f"""
            impl<'a> From<&'a {group.owned_enum_ty}> for {group.ref_enum_ty}<'a> {{
                fn from(node: &'a {group.owned_enum_ty}) -> Self {{
                    match node {{
        """)
        for node in group.nodes:
            out.append(
                f"{group.owned_enum_ty}::{node.variant}(node) => {group.ref_enum_ty}::{node.variant}(node),"
            )
        out.append("""
                    }
                }
            }
        """)

        # Conversion from a reference to each node struct.
        for node in group.nodes:
            out.append(f"""
                impl<'a> From<&'a {node.ty}> for {group.ref_enum_ty}<'a> {{
                    fn from(node: &'a {node.ty}) -> Self {{
                        Self::{node.variant}(node)
                    }}
                }}
            """)

        # `Ranged`, delegating to the referenced node.
        out.append(f"""
            impl ruff_text_size::Ranged for {group.ref_enum_ty}<'_> {{
                fn range(&self) -> ruff_text_size::TextRange {{
                    match self {{
        """)
        for node in group.nodes:
            out.append(f"Self::{node.variant}(node) => node.range(),")
        out.append("""
                    }
                }
            }
        """)

        # `HasNodeIndex`, delegating to the referenced node.
        out.append(f"""
            impl crate::HasNodeIndex for {group.ref_enum_ty}<'_> {{
                fn node_index(&self) -> &crate::AtomicNodeIndex {{
                    match self {{
        """)
        for node in group.nodes:
            out.append(f"Self::{node.variant}(node) => node.node_index(),")
        out.append("""
                    }
                }
            }
        """)
|
|
|
|
|
|
# ------------------------------------------------------------------------------
|
|
# AnyNodeRef
|
|
|
|
|
|
def write_anynoderef(out: list[str], ast: Ast) -> None:
    """
    Create the AnyNodeRef type.

    ```rust
    pub enum AnyNodeRef<'a> {
        ...
        TypeParamTypeVar(&'a TypeParamTypeVar),
        TypeParamTypeVarTuple(&'a TypeParamTypeVarTuple),
        ...
    }
    ```

    Also creates:
    - `impl<'a> From<&'a TypeParam> for AnyNodeRef<'a>`
    - `impl<'a> From<TypeParamRef<'a>> for AnyNodeRef<'a>`
    - `impl<'a> From<&'a TypeParamTypeVarTuple> for AnyNodeRef<'a>`
    - `impl Ranged for AnyNodeRef<'_>`
    - `impl HasNodeIndex for AnyNodeRef<'_>`
    - `fn AnyNodeRef::as_ptr(&self) -> std::ptr::NonNull<()>`
    - `fn AnyNodeRef::visit_source_order(self, visitor: &mut impl SourceOrderVisitor)`
    - `fn AnyNodeRef::as_type_param_ref(self) -> Option<TypeParamRef<'a>>`
    - `fn AnyNodeRef::is_type_param(self) -> bool`
    """

    # The enum itself: one variant per node, named after the node.
    out.append("""
    /// A flattened enumeration of all AST nodes.
    #[derive(Copy, Clone, Debug, is_macro::Is, PartialEq)]
    pub enum AnyNodeRef<'a> {
    """)
    for node in ast.all_nodes:
        out.append(f"""{node.name}(&'a {node.ty}),""")
    out.append("""
    }
    """)

    for group in ast.groups:
        # Conversion from a reference to the group's owned enum.
        out.append(f"""
            impl<'a> From<&'a {group.owned_enum_ty}> for AnyNodeRef<'a> {{
                fn from(node: &'a {group.owned_enum_ty}) -> AnyNodeRef<'a> {{
                    match node {{
        """)
        for node in group.nodes:
            out.append(
                f"{group.owned_enum_ty}::{node.variant}(node) => AnyNodeRef::{node.name}(node),"
            )
        out.append("""
                    }
                }
            }
        """)

        # Conversion from the group's reference enum.
        out.append(f"""
            impl<'a> From<{group.ref_enum_ty}<'a>> for AnyNodeRef<'a> {{
                fn from(node: {group.ref_enum_ty}<'a>) -> AnyNodeRef<'a> {{
                    match node {{
        """)
        for node in group.nodes:
            out.append(
                f"{group.ref_enum_ty}::{node.variant}(node) => AnyNodeRef::{node.name}(node),"
            )
        out.append("""
                    }
                }
            }
        """)

        # `as_*` methods to convert from `AnyNodeRef` to e.g. `ExprRef`
        out.append(f"""
            impl<'a> AnyNodeRef<'a> {{
                pub fn as_{to_snake_case(group.ref_enum_ty)}(self) -> Option<{group.ref_enum_ty}<'a>> {{
                    match self {{
        """)
        for node in group.nodes:
            out.append(
                f"Self::{node.name}(node) => Some({group.ref_enum_ty}::{node.variant}(node)),"
            )
        out.append("""
                        _ => None,
                    }
                }
            }
        """)

    # Conversion from a reference to each node struct.
    for node in ast.all_nodes:
        out.append(f"""
            impl<'a> From<&'a {node.ty}> for AnyNodeRef<'a> {{
                fn from(node: &'a {node.ty}) -> AnyNodeRef<'a> {{
                    AnyNodeRef::{node.name}(node)
                }}
            }}
        """)

    # `Ranged`, delegating to the referenced node.
    out.append("""
        impl ruff_text_size::Ranged for AnyNodeRef<'_> {
            fn range(&self) -> ruff_text_size::TextRange {
                match self {
    """)
    for node in ast.all_nodes:
        out.append(f"""AnyNodeRef::{node.name}(node) => node.range(),""")
    out.append("""
                }
            }
        }
    """)

    # `HasNodeIndex`, delegating to the referenced node.
    out.append("""
        impl crate::HasNodeIndex for AnyNodeRef<'_> {
            fn node_index(&self) -> &crate::AtomicNodeIndex {
                match self {
    """)
    for node in ast.all_nodes:
        out.append(f"""AnyNodeRef::{node.name}(node) => node.node_index(),""")
    out.append("""
                }
            }
        }
    """)

    # `as_ptr`: the referenced node as a type-erased pointer.
    out.append("""
        impl AnyNodeRef<'_> {
            pub fn as_ptr(&self) -> std::ptr::NonNull<()> {
                match self {
    """)
    for node in ast.all_nodes:
        out.append(
            f"AnyNodeRef::{node.name}(node) => std::ptr::NonNull::from(*node).cast(),"
        )
    out.append("""
                }
            }
        }
    """)

    # `visit_source_order`, dispatching to the referenced node.
    out.append("""
        impl<'a> AnyNodeRef<'a> {
            pub fn visit_source_order<'b, V>(self, visitor: &mut V)
            where
                V: crate::visitor::source_order::SourceOrderVisitor<'b> + ?Sized,
                'a: 'b,
            {
                match self {
    """)
    for node in ast.all_nodes:
        out.append(
            f"AnyNodeRef::{node.name}(node) => node.visit_source_order(visitor),"
        )
    out.append("""
                }
            }
        }
    """)

    # One `is_<label>` predicate per group, matching any node of that group.
    for group in ast.groups:
        out.append(f"""
        impl AnyNodeRef<'_> {{
            pub const fn is_{group.anynode_is_label}(self) -> bool {{
                matches!(self,
        """)
        for i, node in enumerate(group.nodes):
            # Separate the alternatives of the `matches!` pattern with `|`.
            if i > 0:
                out.append("|")
            out.append(f"""AnyNodeRef::{node.name}(_)""")
        out.append("""
                )
            }
        }
        """)
|
|
|
|
|
|
# ------------------------------------------------------------------------------
|
|
# AnyRootNodeRef
|
|
|
|
|
|
def write_root_anynoderef(out: list[str], ast: Ast) -> None:
    """
    Create the AnyRootNodeRef type.

    ```rust
    pub enum AnyRootNodeRef<'a> {
        ...
        TypeParam(&'a TypeParam),
        ...
    }
    ```

    Also creates:
    - `impl<'a> From<&'a TypeParam> for AnyRootNodeRef<'a>`
    - `impl<'a> TryFrom<AnyRootNodeRef<'a>> for &'a TypeParam`
    - `impl<'a> TryFrom<AnyRootNodeRef<'a>> for &'a TypeParamTypeVarTuple`
    - `impl Ranged for AnyRootNodeRef<'_>`
    - `impl HasNodeIndex for AnyRootNodeRef<'_>`
    - `fn AnyRootNodeRef::visit_source_order(self, visitor: &mut impl SourceOrderVisitor)`
    """

    # The enum itself: one variant per group plus one per ungrouped node.
    out.append("""
        /// An enumeration of all AST nodes.
        ///
        /// Unlike `AnyNodeRef`, this type does not flatten nested enums, so its variants only
        /// consist of the "root" AST node types. This is useful as it exposes references to the
        /// original enums, not just references to their inner values.
        ///
        /// For example, `AnyRootNodeRef::Mod` contains a reference to the `Mod` enum, while
        /// `AnyNodeRef` has top-level `AnyNodeRef::ModModule` and `AnyNodeRef::ModExpression`
        /// variants.
        #[derive(Copy, Clone, Debug, PartialEq)]
        pub enum AnyRootNodeRef<'a> {
    """)
    for group in ast.groups:
        out.append(f"""{group.name}(&'a {group.owned_enum_ty}),""")
    for node in ast.ungrouped_nodes:
        out.append(f"""{node.name}(&'a {node.ty}),""")
    out.append("""
        }
    """)

    for group in ast.groups:
        # Conversion from a reference to the group's owned enum.
        out.append(f"""
            impl<'a> From<&'a {group.owned_enum_ty}> for AnyRootNodeRef<'a> {{
                fn from(node: &'a {group.owned_enum_ty}) -> AnyRootNodeRef<'a> {{
                    AnyRootNodeRef::{group.name}(node)
                }}
            }}
        """)

        # Fallible conversion back to a reference to the owned enum.
        out.append(f"""
            impl<'a> TryFrom<AnyRootNodeRef<'a>> for &'a {group.owned_enum_ty} {{
                type Error = ();
                fn try_from(node: AnyRootNodeRef<'a>) -> Result<&'a {group.owned_enum_ty}, ()> {{
                    match node {{
                        AnyRootNodeRef::{group.name}(node) => Ok(node),
                        _ => Err(())
                    }}
                }}
            }}
        """)

        # Fallible conversion to a reference to one specific node of the group.
        for node in group.nodes:
            out.append(f"""
                impl<'a> TryFrom<AnyRootNodeRef<'a>> for &'a {node.ty} {{
                    type Error = ();
                    fn try_from(node: AnyRootNodeRef<'a>) -> Result<&'a {node.ty}, ()> {{
                        match node {{
                            AnyRootNodeRef::{group.name}({group.owned_enum_ty}::{node.variant}(node)) => Ok(node),
                            _ => Err(())
                        }}
                    }}
                }}
            """)

    # The same pair of conversions for every ungrouped node.
    for node in ast.ungrouped_nodes:
        out.append(f"""
            impl<'a> From<&'a {node.ty}> for AnyRootNodeRef<'a> {{
                fn from(node: &'a {node.ty}) -> AnyRootNodeRef<'a> {{
                    AnyRootNodeRef::{node.name}(node)
                }}
            }}
        """)

        out.append(f"""
            impl<'a> TryFrom<AnyRootNodeRef<'a>> for &'a {node.ty} {{
                type Error = ();
                fn try_from(node: AnyRootNodeRef<'a>) -> Result<&'a {node.ty}, ()> {{
                    match node {{
                        AnyRootNodeRef::{node.name}(node) => Ok(node),
                        _ => Err(())
                    }}
                }}
            }}
        """)

    # `Ranged`, delegating to the referenced enum or node.
    out.append("""
        impl ruff_text_size::Ranged for AnyRootNodeRef<'_> {
            fn range(&self) -> ruff_text_size::TextRange {
                match self {
    """)
    for group in ast.groups:
        out.append(f"""AnyRootNodeRef::{group.name}(node) => node.range(),""")
    for node in ast.ungrouped_nodes:
        out.append(f"""AnyRootNodeRef::{node.name}(node) => node.range(),""")
    out.append("""
                }
            }
        }
    """)

    # `HasNodeIndex`, delegating to the referenced enum or node.
    out.append("""
        impl crate::HasNodeIndex for AnyRootNodeRef<'_> {
            fn node_index(&self) -> &crate::AtomicNodeIndex {
                match self {
    """)
    for group in ast.groups:
        out.append(f"""AnyRootNodeRef::{group.name}(node) => node.node_index(),""")
    for node in ast.ungrouped_nodes:
        out.append(f"""AnyRootNodeRef::{node.name}(node) => node.node_index(),""")
    out.append("""
                }
            }
        }
    """)

    # `visit_source_order`, dispatching to the referenced enum or node.
    out.append("""
        impl<'a> AnyRootNodeRef<'a> {
            pub fn visit_source_order<'b, V>(self, visitor: &mut V)
            where
                V: crate::visitor::source_order::SourceOrderVisitor<'b> + ?Sized,
                'a: 'b,
            {
                match self {
    """)
    for group in ast.groups:
        out.append(
            f"""AnyRootNodeRef::{group.name}(node) => node.visit_source_order(visitor),"""
        )
    for node in ast.ungrouped_nodes:
        out.append(
            f"""AnyRootNodeRef::{node.name}(node) => node.visit_source_order(visitor),"""
        )
    out.append("""
                }
            }
        }
    """)
|
|
|
|
|
|
# ------------------------------------------------------------------------------
|
|
# NodeKind
|
|
|
|
|
|
def write_nodekind(out: list[str], ast: Ast) -> None:
    """
    Create the NodeKind type.

    ```rust
    pub enum NodeKind {
        ...
        TypeParamTypeVar,
        TypeParamTypeVarTuple,
        ...
    }
    ```

    Also creates:
    - `fn AnyNodeRef::kind(self) -> NodeKind`
    """

    # The enum itself: one field-less variant per node, named after the node.
    out.append("""
        #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
        pub enum NodeKind {
    """)
    for node in ast.all_nodes:
        out.append(f"""{node.name},""")
    out.append("""
        }
    """)

    # `AnyNodeRef::kind`: map each `AnyNodeRef` variant to its `NodeKind`.
    out.append("""
        impl AnyNodeRef<'_> {
            pub const fn kind(self) -> NodeKind {
                match self {
    """)
    for node in ast.all_nodes:
        out.append(f"""AnyNodeRef::{node.name}(_) => NodeKind::{node.name},""")
    out.append("""
                }
            }
        }
    """)
|
|
|
|
|
|
# ------------------------------------------------------------------------------
|
|
# Node structs
|
|
|
|
|
|
def write_node(out: list[str], ast: Ast) -> None:
    """
    Emit the Rust struct definition of every grouped node that declares fields.

    Each struct starts with the `node_index` and `range` members, followed by
    the fields from `ast.toml`. Nodes without a `fields` entry are skipped.
    """
    group_names = [group.name for group in ast.groups]
    nodes_with_fields = (
        node
        for group in ast.groups
        for node in group.nodes
        if node.fields is not None
    )

    for node in nodes_with_fields:
        if node.doc is not None:
            write_rustdoc(out, node.doc)

        derives = ", ".join(["Clone", "Debug", "PartialEq", *node.derives])
        out.append("#[derive(" + derives + ")]")
        out.append(f"pub struct {node.name} {{")
        out.append("pub node_index: crate::AtomicNodeIndex,")
        out.append("pub range: ruff_text_size::TextRange,")

        for field in node.fields:
            ty = field.parsed_ty

            # Build the Rust type from the inside out.
            rust_ty = f"{ty.name}"
            if ty.name in types_requiring_crate_prefix:
                rust_ty = f"crate::{rust_ty}"
            if ty.slice_:
                rust_ty = f"[{rust_ty}]"
            # Group enums and slices are boxed unless they sit in a `Vec`.
            needs_box = ty.name in group_names or ty.slice_
            if needs_box and ty.seq is False:
                rust_ty = f"Box<{rust_ty}>"

            if ty.seq:
                rust_ty = f"Vec<{rust_ty}>"
            elif ty.optional:
                rust_ty = f"Option<{rust_ty}>"

            out.append(f"pub {field.name}: " + rust_ty + ",")

        out.append("}")
        out.append("")
|
|
|
|
|
|
# ------------------------------------------------------------------------------
|
|
# Source order visitor
|
|
|
|
|
|
def write_source_order(out: list[str], ast: Ast) -> None:
    """
    Emit a `visit_source_order` method for every grouped node that declares
    fields and does not set the `custom_source_order` option.

    The generated method destructures the node and calls one visitor method
    per field returned by `Node.fields_in_source_order()`.
    """
    for group in ast.groups:
        for node in group.nodes:
            if node.fields is None or node.custom_source_order:
                continue
            name = node.name
            fields_list = ""
            body = ""

            # Destructuring pattern: every field is listed, and the ones that
            # are not visited are bound to `_`.
            for field in node.fields:
                if field.skip_source_order():
                    fields_list += f"{field.name}: _,\n"
                else:
                    fields_list += f"{field.name},\n"
            fields_list += "range: _,\n"
            fields_list += "node_index: _,\n"

            for field in node.fields_in_source_order():
                # Use the override from `type_to_visitor_function` when there
                # is one, otherwise the default `visit_<snake_case_type>`.
                visitor_name = (
                    type_to_visitor_function.get(
                        field.parsed_ty.inner, VisitorInfo("")
                    ).name
                    or f"visit_{to_snake_case(field.parsed_ty.inner)}"
                )
                visits_sequence = type_to_visitor_function.get(
                    field.parsed_ty.inner, VisitorInfo("")
                ).accepts_sequence

                # Annotation fields always go through `visit_annotation`.
                if field.is_annotation:
                    visitor_name = "visit_annotation"

                if field.parsed_ty.optional:
                    body += f"""
                            if let Some({field.name}) = {field.name} {{
                                visitor.{visitor_name}({field.name});
                            }}\n
                    """
                elif not visits_sequence and field.parsed_ty.seq:
                    # The visitor takes single items: loop over the sequence.
                    body += f"""
                            for elm in {field.name} {{
                                visitor.{visitor_name}(elm);
                            }}
                    """
                else:
                    body += f"visitor.{visitor_name}({field.name});\n"

            # Name the parameter `_` when the generated body never uses it.
            visitor_arg_name = "visitor"
            if len(node.fields_in_source_order()) == 0:
                visitor_arg_name = "_"

            out.append(f"""
                impl {name} {{
                    pub(crate) fn visit_source_order<'a, V>(&'a self, {visitor_arg_name}: &mut V)
                    where
                        V: SourceOrderVisitor<'a> + ?Sized,
                    {{
                        let {name} {{
                            {fields_list}
                        }} = self;
                        {body}
                    }}
                }}
            """)
|
|
|
|
|
|
# ------------------------------------------------------------------------------
|
|
# Format and write output
|
|
|
|
|
|
def generate(ast: Ast) -> list[str]:
    """Run every writer in its fixed order and return the collected output chunks."""
    chunks: list[str] = []
    write_preamble(chunks)
    writers = (
        write_owned_enum,
        write_ref_enum,
        write_anynoderef,
        write_root_anynoderef,
        write_nodekind,
        write_node,
        write_source_order,
    )
    for write in writers:
        write(chunks, ast)
    return chunks
|
|
|
|
|
|
def write_output(root: Path, out: list[str]) -> None:
    """Join the chunks with newlines, format them and write `src/generated.rs`."""
    destination = root / "crates" / "ruff_python_ast" / "src" / "generated.rs"
    unformatted = "\n".join(out)
    destination.write_text(rustfmt(unformatted))
|
|
|
|
|
|
# ------------------------------------------------------------------------------
|
|
# Main
|
|
|
|
|
|
def main() -> None:
    """Locate the repository root with git, then regenerate the Rust AST file."""
    toplevel = check_output(["git", "rev-parse", "--show-toplevel"], text=True)
    root = Path(toplevel.strip())
    write_output(root, generate(load_ast(root)))
|
|
|
|
|
|
# Only regenerate when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|