mirror of
				https://github.com/astral-sh/ruff.git
				synced 2025-10-26 18:06:43 +00:00 
			
		
		
		
	 6f7b1c9bb3
			
		
	
	
		6f7b1c9bb3
		
			
		
	
	
	
	
		
			
			## Summary
Setting `TY_MEMORY_REPORT=full` will generate and print a memory usage
report to the CLI after a `ty check` run:
```
=======SALSA STRUCTS=======
`Definition`                                       metadata=7.24MB   fields=17.38MB  count=181062
`Expression`                                       metadata=4.45MB   fields=5.94MB   count=92804
`member_lookup_with_policy_::interned_arguments`   metadata=1.97MB   fields=2.25MB   count=35176
...
=======SALSA QUERIES=======
`File -> ty_python_semantic::semantic_index::SemanticIndex`
    metadata=11.46MB  fields=88.86MB  count=1638
`Definition -> ty_python_semantic::types::infer::TypeInference`
    metadata=24.52MB  fields=86.68MB  count=146018
`File -> ruff_db::parsed::ParsedModule`
    metadata=0.12MB   fields=69.06MB  count=1642
...
=======SALSA SUMMARY=======
TOTAL MEMORY USAGE: 577.61MB
    struct metadata = 29.00MB
    struct fields = 35.68MB
    memo metadata = 103.87MB
    memo fields = 409.06MB
```
Eventually, we should integrate these numbers into CI in some form. One
current limitation is that heap allocations inside salsa structs (e.g.
interned values) are not tracked, although memoized values should have full
coverage. We may also want a peak memory usage counter that accounts
for non-salsa memory. However, peak usage is relatively simple to profile manually
(e.g. `time -v ty check`), and such a counter would require a compile-time option to
avoid runtime overhead.
		
	
			
		
			
				
	
	
		
			1103 lines
		
	
	
	
		
			34 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			1103 lines
		
	
	
	
		
			34 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #!/usr/bin/python
 | |
| # /// script
 | |
| # requires-python = ">=3.11"
 | |
| # dependencies = []
 | |
| # ///
 | |
| 
 | |
| from __future__ import annotations
 | |
| 
 | |
| import re
 | |
| from dataclasses import dataclass
 | |
| from pathlib import Path
 | |
| from subprocess import check_output
 | |
| from typing import Any
 | |
| 
 | |
| import tomllib
 | |
| 
 | |
# Types that require a `crate::` prefix when referenced from generated code.
# We can slowly remove these types as we move them to generate scripts.
# (Kept in alphabetical order; it is a set, so order has no effect.)
types_requiring_crate_prefix = {
    "Alias",
    "Arguments",
    "BoolOp",
    "BytesLiteralValue",
    "CmpOp",
    "Comprehension",
    "Decorator",
    "DictItem",
    "ElifElseClause",
    "ExprContext",
    "FStringValue",
    "Identifier",
    "IpyEscapeKind",
    "MatchCase",
    "Number",
    "Operator",
    "Parameters",
    "StringLiteralValue",
    "TStringValue",
    "TypeParams",
    "UnaryOp",
    "WithItem",
}
 | |
| 
 | |
| 
 | |
@dataclass
class VisitorInfo:
    """
    Describes the visitor method used for a particular AST node type.

    `name` is the visitor method to call. `accepts_sequence` is true when that
    method accepts a sequence of items.
    """

    name: str
    accepts_sequence: bool = False


# Visitor overrides keyed by AST node type. Node types absent from this map use
# the default `visit_*` method; every entry here either has a different method
# name or takes a sequence of items.
type_to_visitor_function: dict[str, VisitorInfo] = {
    "TypeParams": VisitorInfo(name="visit_type_params", accepts_sequence=True),
    "Parameters": VisitorInfo(name="visit_parameters", accepts_sequence=True),
    "Stmt": VisitorInfo(name="visit_body", accepts_sequence=True),
    "Arguments": VisitorInfo(name="visit_arguments", accepts_sequence=True),
}
 | |
| 
 | |
| 
 | |
def rustfmt(code: str) -> str:
    """
    Format Rust source `code` by piping it through `rustfmt --emit=stdout`.

    Returns the formatted source. `check_output` raises
    `subprocess.CalledProcessError` if rustfmt exits with a non-zero status.
    """
    command = ["rustfmt", "--emit=stdout"]
    formatted = check_output(command, input=code, text=True)
    return formatted
 | |
| 
 | |
| 
 | |
def to_snake_case(node: str) -> str:
    """
    Convert a CamelCase identifier to snake_case.

    Every ASCII uppercase letter is prefixed with `_`, the whole string is
    lowercased, and any leading underscores are stripped, e.g.
    `TypeParamRef` -> `type_param_ref`.
    """
    # Only ASCII `A`-`Z` gets an underscore, so non-ASCII capitals are left alone.
    pieces = [f"_{ch}" if "A" <= ch <= "Z" else ch for ch in node]
    return "".join(pieces).lower().lstrip("_")
 | |
| 
 | |
| 
 | |
def write_rustdoc(out: list[str], doc: str) -> None:
    """Append `doc` to `out` as Rust `///` doc-comment lines, one per `\\n`-separated line."""
    out.extend(f"/// {text}" for text in doc.split("\n"))
 | |
| 
 | |
| 
 | |
| # ------------------------------------------------------------------------------
 | |
| # Read AST description
 | |
| 
 | |
| 
 | |
def load_ast(root: Path) -> Ast:
    """Parse `crates/ruff_python_ast/ast.toml` beneath `root` into an `Ast`."""
    toml_path = root / "crates" / "ruff_python_ast" / "ast.toml"
    # `tomllib.load` requires a file opened in binary mode.
    with toml_path.open("rb") as handle:
        parsed = tomllib.load(handle)
    return Ast(parsed)
 | |
| 
 | |
| 
 | |
| # ------------------------------------------------------------------------------
 | |
| # Preprocess
 | |
| 
 | |
| 
 | |
@dataclass
class Ast:
    """
    The parsed representation of the `ast.toml` file.

    Holds every Python AST syntax node, and the groups (`Stmt`, `Expr`, etc.)
    that they belong to. The special `ungrouped` table contributes to
    `ungrouped_nodes` instead of `groups`. `all_nodes` contains the nodes of
    every table, including `ungrouped`, in file order.
    """

    groups: list[Group]
    ungrouped_nodes: list[Node]
    all_nodes: list[Node]

    def __init__(self, ast: dict[str, Any]) -> None:
        self.groups = []
        self.ungrouped_nodes = []
        self.all_nodes = []
        for table_name, table in ast.items():
            parsed = Group(table_name, table)
            self.all_nodes += parsed.nodes
            if table_name != "ungrouped":
                self.groups.append(parsed)
                continue
            # Shares the list object with `parsed.nodes`.
            self.ungrouped_nodes = parsed.nodes
 | |
| 
 | |
| 
 | |
@dataclass
class Group:
    """
    A named group of AST nodes from `ast.toml` (e.g. `Stmt`, `Expr`).

    Each group is generated as an owned enum (`owned_enum_ty`) and a reference
    enum (`ref_enum_ty`) whose variants wrap the group's nodes.
    """

    name: str
    nodes: list[Node]
    # Name of the generated owned enum; set to the same string as `name`.
    owned_enum_ty: str
    # Name of the generated reference enum, `<name>Ref`. Declared as a field so
    # that it takes part in the dataclass `repr`/`eq` like the other attributes.
    ref_enum_ty: str

    # When true, generated `is_*` method names get the group name appended,
    # e.g. `is_type_var_type_param` instead of `is_type_var`.
    add_suffix_to_is_methods: bool
    # Label for the `AnyNodeRef::is_<label>` method that matches any node of
    # this group. Defaults to the snake_case group name.
    anynode_is_label: str
    doc: str | None

    def __init__(self, group_name: str, group: dict[str, Any]) -> None:
        """
        Build a group from its `ast.toml` table.

        Args:
            group_name: the TOML table name, e.g. `Expr`.
            group: the TOML table. It must contain `nodes`, a mapping of node name
                to node table. The `add_suffix_to_is_methods`, `anynode_is_label`
                and `doc` keys are optional.
        """
        self.name = group_name
        self.owned_enum_ty = group_name
        self.ref_enum_ty = f"{group_name}Ref"
        self.add_suffix_to_is_methods = group.get("add_suffix_to_is_methods", False)
        # Only derive the default label when the TOML does not supply one.
        if "anynode_is_label" in group:
            self.anynode_is_label = group["anynode_is_label"]
        else:
            self.anynode_is_label = to_snake_case(group_name)
        self.doc = group.get("doc")
        self.nodes = [
            Node(self, node_name, node) for node_name, node in group["nodes"].items()
        ]
 | |
| 
 | |
| 
 | |
@dataclass
class Node:
    """A single AST node definition from `ast.toml`."""

    name: str
    # Enum variant name used inside the node's group enum.
    variant: str
    # Fully qualified Rust type, e.g. `crate::ExprFoo`.
    ty: str
    doc: str | None
    fields: list[Field] | None
    derives: list[str]
    custom_source_order: bool
    # Optional explicit visiting order, as a list of field names.
    source_order: list[str] | None

    def __init__(self, group: Group, node_name: str, node: dict[str, Any]) -> None:
        self.name = node_name
        # The variant defaults to the node name with the group prefix removed,
        # e.g. node `ExprFoo` in group `Expr` becomes variant `Foo`.
        self.variant = node.get("variant", node_name.removeprefix(group.name))
        self.ty = f"crate::{node_name}"
        fields = node.get("fields")
        self.fields = None if fields is None else [Field(f) for f in fields]
        self.custom_source_order = node.get("custom_source_order", False)
        self.derives = node.get("derives", [])
        self.doc = node.get("doc")
        self.source_order = node.get("source_order")

    def fields_in_source_order(self) -> list[Field]:
        """
        Return the fields of this node that are visited in source order.

        Fields whose `skip_source_order()` is true are excluded. If the node has an
        explicit `source_order`, fields are returned in that order. Otherwise they
        are returned in declaration order.

        Raises:
            ValueError: if a name in `source_order` does not match any field that
                is visited in source order.
        """
        if self.fields is None:
            return []
        visitable = [field for field in self.fields if not field.skip_source_order()]
        if self.source_order is None:
            return visitable

        # On duplicate names the first field wins, like a linear search would.
        by_name: dict[str, Field] = {}
        for field in visitable:
            by_name.setdefault(field.name, field)

        ordered = []
        for field_name in self.source_order:
            field = by_name.get(field_name)
            if field is None:
                # Previously an unmatched name silently produced the node's last
                # field (or `None`); fail loudly so `ast.toml` typos are caught.
                raise ValueError(
                    f"`source_order` of node `{self.name}` references unknown or "
                    f"non-visitable field `{field_name}`"
                )
            ordered.append(field)
        return ordered
 | |
| 
 | |
| 
 | |
@dataclass
class Field:
    """A single field of an AST node, as declared in `ast.toml`."""

    name: str
    # Field type in the AST field-type syntax, e.g. `Expr`, `Box<Expr>?`, `Expr*`.
    ty: str
    # Value of the optional `skip_visit` key.
    _skip_visit: bool
    is_annotation: bool
    # `ty` parsed into its name, inner type and `?` / `*` / `&` markers.
    parsed_ty: FieldType

    def __init__(self, field: dict[str, Any]) -> None:
        self.name, self.ty = field["name"], field["type"]
        self.parsed_ty = FieldType(self.ty)
        self._skip_visit = field.get("skip_visit", False)
        self.is_annotation = field.get("is_annotation", False)

    def skip_source_order(self) -> bool:
        """
        Whether this field is left out of source-order visiting.

        This is true when the field sets `skip_visit`, or when its inner type is
        one of the types listed below.
        """
        if self._skip_visit:
            return self._skip_visit
        not_visited = (
            "str",
            "ExprContext",
            "Name",
            "u32",
            "bool",
            "Number",
            "IpyEscapeKind",
        )
        return self.parsed_ty.inner in not_visited
 | |
| 
 | |
| 
 | |
def extract_type_argument(rust_type_str: str) -> str:
    """
    Extract the type argument from a Rust type written in AST field-type syntax.

    The `*`, `?` and `&` markers are removed first. Then the text between the
    first `<` and the last `>` is returned, stripped of surrounding whitespace
    and of any `crate::` prefix:

        Box<str>   -> str
        Box<Expr?> -> Expr

    A type without `<` is returned as-is, minus the markers. Nested types are
    not supported: only the outermost brackets are removed
    (`Box<Vec<Expr>>` -> `Vec<Expr>`).

    Raises:
        ValueError: if a `<` is not followed by a matching `>`.
    """
    cleaned = rust_type_str.translate(str.maketrans("", "", "*?&"))

    start = cleaned.find("<")
    if start < 0:
        return cleaned

    # `rfind` returns -1 when there is no `>`, which is also `<= start`.
    end = cleaned.rfind(">")
    if end <= start:
        raise ValueError(f"Brackets are not balanced for type {cleaned}")

    return cleaned[start + 1 : end].strip().replace("crate::", "")
 | |
| 
 | |
| 
 | |
| @dataclass
 | |
| class FieldType:
 | |
|     rule: str
 | |
|     name: str
 | |
|     inner: str
 | |
|     seq: bool = False
 | |
|     optional: bool = False
 | |
|     slice_: bool = False
 | |
| 
 | |
|     def __init__(self, rule: str) -> None:
 | |
|         self.rule = rule
 | |
|         self.name = ""
 | |
|         self.inner = extract_type_argument(rule)
 | |
| 
 | |
|         # The following cases are the limitations of this parser(and not used in the ast.toml):
 | |
|         # * Rules that involve declaring a sequence with optional items e.g. Vec<Option<...>>
 | |
|         last_pos = len(rule) - 1
 | |
|         for i, ch in enumerate(rule):
 | |
|             if ch == "?":
 | |
|                 if i == last_pos:
 | |
|                     self.optional = True
 | |
|                 else:
 | |
|                     raise ValueError(f"`?` must be at the end: {rule}")
 | |
|             elif ch == "*":
 | |
|                 if self.slice_:  # The * after & is a slice
 | |
|                     continue
 | |
|                 if i == last_pos:
 | |
|                     self.seq = True
 | |
|                 else:
 | |
|                     raise ValueError(f"`*` must be at the end: {rule}")
 | |
|             elif ch == "&":
 | |
|                 if i == 0 and rule.endswith("*"):
 | |
|                     self.slice_ = True
 | |
|                 else:
 | |
|                     raise ValueError(
 | |
|                         f"`&` must be at the start and end with `*`: {rule}"
 | |
|                     )
 | |
|             else:
 | |
|                 self.name += ch
 | |
| 
 | |
|         if self.optional and (self.seq or self.slice_):
 | |
|             raise ValueError(f"optional field cannot be sequence or slice: {rule}")
 | |
| 
 | |
| 
 | |
| # ------------------------------------------------------------------------------
 | |
| # Preamble
 | |
| 
 | |
| 
 | |
def write_preamble(out: list[str]) -> None:
    """Append the generated-file preamble (a do-not-edit notice and two `use` items) to `out`."""
    header = """
    // This is a generated file. Don't modify it by hand!
    // Run `crates/ruff_python_ast/generate.py` to re-generate the file.

    use crate::name::Name;
    use crate::visitor::source_order::SourceOrderVisitor;
    """
    out.append(header)
 | |
| 
 | |
| 
 | |
| # ------------------------------------------------------------------------------
 | |
| # Owned enum
 | |
| 
 | |
| 
 | |
def write_owned_enum(out: list[str], ast: Ast) -> None:
    """
    Create an enum for each group that contains an owned copy of a syntax node.

    ```rust
    pub enum TypeParam {
        TypeVar(TypeParamTypeVar),
        TypeVarTuple(TypeParamTypeVarTuple),
        ...
    }
    ```

    Also creates:
    - `impl Ranged for TypeParam`
    - `impl HasNodeIndex for TypeParam`
    - `TypeParam::visit_source_order`
    - `impl From<TypeParamTypeVar> for TypeParam`
    - `impl Ranged for TypeParamTypeVar` (for every node, including ungrouped ones)
    - `impl HasNodeIndex for TypeParamTypeVar` (for every node, including ungrouped ones)
    - `fn TypeParam::is_type_var() -> bool`, plus the accessors `type_var`,
      `expect_type_var`, `as_type_var` and `as_type_var_mut`

    If the `add_suffix_to_is_methods` group option is true, then all of these
    accessor names get the group name appended, e.g. the `is_type_var` method
    will be named `is_type_var_type_param`.

    The generated Rust source is appended to `out` as a sequence of fragments.
    """

    for group in ast.groups:
        # Enum definition, preceded by a blank line and the group's optional doc comment.
        out.append("")
        if group.doc is not None:
            write_rustdoc(out, group.doc)
        out.append("#[derive(Clone, Debug, PartialEq)]")
        out.append('#[cfg_attr(feature = "get-size", derive(get_size2::GetSize))]')
        out.append(f"pub enum {group.owned_enum_ty} {{")
        # One variant per node, wrapping the node struct by value.
        for node in group.nodes:
            out.append(f"{node.variant}({node.ty}),")
        out.append("}")

        # `From<Node>` so each node struct converts into its group enum.
        for node in group.nodes:
            out.append(f"""
            impl From<{node.ty}> for {group.owned_enum_ty} {{
                fn from(node: {node.ty}) -> Self {{
                    Self::{node.variant}(node)
                }}
            }}
            """)

        # `Ranged` for the enum: delegate to the wrapped node.
        out.append(f"""
        impl ruff_text_size::Ranged for {group.owned_enum_ty} {{
            fn range(&self) -> ruff_text_size::TextRange {{
                match self {{
        """)
        for node in group.nodes:
            out.append(f"Self::{node.variant}(node) => node.range(),")
        out.append("""
                }
            }
        }
        """)

        # `HasNodeIndex` for the enum: delegate to the wrapped node.
        out.append(f"""
        impl crate::HasNodeIndex for {group.owned_enum_ty} {{
            fn node_index(&self) -> &crate::AtomicNodeIndex {{
                match self {{
        """)
        for node in group.nodes:
            out.append(f"Self::{node.variant}(node) => node.node_index(),")
        out.append("""
                }
            }
        }
        """)

        # Per-variant accessors. `group.name` is the same string as
        # `group.owned_enum_ty` (both come from the group's TOML key).
        out.append(
            "#[allow(dead_code, clippy::match_wildcard_for_single_variants)]"
        )  # Not all is_methods are used
        out.append(f"impl {group.name} {{")
        for node in group.nodes:
            is_name = to_snake_case(node.variant)
            variant_name = node.variant
            match_arm = f"Self::{variant_name}"
            if group.add_suffix_to_is_methods:
                is_name = to_snake_case(node.variant + group.name)
            if len(group.nodes) > 1:
                # Several variants: each `match` needs a `_` fallback arm.
                out.append(f"""
                    #[inline]
                    pub const fn is_{is_name}(&self) -> bool {{
                        matches!(self, {match_arm}(_))
                    }}

                    #[inline]
                    pub fn {is_name}(self) -> Option<{node.ty}> {{
                        match self {{
                            {match_arm}(val) => Some(val),
                            _ => None,
                        }}
                    }}

                    #[inline]
                    pub fn expect_{is_name}(self) -> {node.ty} {{
                        match self {{
                            {match_arm}(val) => val,
                            _ => panic!("called expect on {{self:?}}"),
                        }}
                    }}

                    #[inline]
                    pub fn as_{is_name}_mut(&mut self) -> Option<&mut {node.ty}> {{
                        match self {{
                            {match_arm}(val) => Some(val),
                            _ => None,
                        }}
                    }}

                    #[inline]
                    pub fn as_{is_name}(&self) -> Option<&{node.ty}> {{
                        match self {{
                            {match_arm}(val) => Some(val),
                            _ => None,
                        }}
                    }}
                           """)
            elif len(group.nodes) == 1:
                # Single variant: every `match` is already exhaustive, so the `_`
                # fallback arms are omitted (they would be unreachable).
                out.append(f"""
                    #[inline]
                    pub const fn is_{is_name}(&self) -> bool {{
                        matches!(self, {match_arm}(_))
                    }}

                    #[inline]
                    pub fn {is_name}(self) -> Option<{node.ty}> {{
                        match self {{
                            {match_arm}(val) => Some(val),
                        }}
                    }}

                    #[inline]
                    pub fn expect_{is_name}(self) -> {node.ty} {{
                        match self {{
                            {match_arm}(val) => val,
                        }}
                    }}

                    #[inline]
                    pub fn as_{is_name}_mut(&mut self) -> Option<&mut {node.ty}> {{
                        match self {{
                            {match_arm}(val) => Some(val),
                        }}
                    }}

                    #[inline]
                    pub fn as_{is_name}(&self) -> Option<&{node.ty}> {{
                        match self {{
                            {match_arm}(val) => Some(val),
                        }}
                    }}
                           """)

        # Close the accessor `impl` (left empty when the group has no nodes).
        out.append("}")

    # `Ranged` for every node struct (grouped and ungrouped), reading the
    # struct's own `range` field.
    for node in ast.all_nodes:
        out.append(f"""
            impl ruff_text_size::Ranged for {node.ty} {{
                fn range(&self) -> ruff_text_size::TextRange {{
                    self.range
                }}
            }}
        """)

    # `HasNodeIndex` for every node struct, borrowing its `node_index` field.
    for node in ast.all_nodes:
        out.append(f"""
            impl crate::HasNodeIndex for {node.ty} {{
                fn node_index(&self) -> &crate::AtomicNodeIndex {{
                    &self.node_index
                }}
            }}
        """)

    # `visit_source_order` on each group enum dispatches to the wrapped node's
    # own `visit_source_order`.
    for group in ast.groups:
        out.append(f"""
            impl {group.owned_enum_ty} {{
                #[allow(unused)]
                pub(crate) fn visit_source_order<'a, V>(&'a self, visitor: &mut V)
                where
                    V: crate::visitor::source_order::SourceOrderVisitor<'a> + ?Sized,
                {{
                    match self {{
        """)
        for node in group.nodes:
            out.append(
                f"{group.owned_enum_ty}::{node.variant}(node) => node.visit_source_order(visitor),"
            )
        out.append("""
                    }
                }
            }
        """)
 | |
| 
 | |
| 
 | |
| # ------------------------------------------------------------------------------
 | |
| # Ref enum
 | |
| 
 | |
| 
 | |
def write_ref_enum(out: list[str], ast: Ast) -> None:
    """
    Create an enum for each group that contains a reference to a syntax node.

    ```rust
    pub enum TypeParamRef<'a> {
        TypeVar(&'a TypeParamTypeVar),
        TypeVarTuple(&'a TypeParamTypeVarTuple),
        ...
    }
    ```

    Also creates:
    - `impl<'a> From<&'a TypeParam> for TypeParamRef<'a>`
    - `impl<'a> From<&'a TypeParamTypeVar> for TypeParamRef<'a>`
    - `impl Ranged for TypeParamRef<'_>`
    - `impl HasNodeIndex for TypeParamRef<'_>`
    - `fn TypeParamRef::is_type_var() -> bool`

    The name of each variant can be customized via the `variant` node option. If
    the `add_suffix_to_is_methods` group option is true, then the `is_type_var`
    method will be named `is_type_var_type_param`.

    The generated Rust source is appended to `out` as a sequence of fragments.
    """

    for group in ast.groups:
        # Enum definition, preceded by a blank line and the group's optional doc comment.
        out.append("")
        if group.doc is not None:
            write_rustdoc(out, group.doc)
        out.append("""#[derive(Clone, Copy, Debug, PartialEq, is_macro::Is)]""")
        out.append('#[cfg_attr(feature = "get-size", derive(get_size2::GetSize))]')
        out.append(f"""pub enum {group.ref_enum_ty}<'a> {{""")
        for node in group.nodes:
            # `#[is(name = ...)]` renames the `is_*` method derived via
            # `is_macro::Is` when the group asks for suffixed names.
            if group.add_suffix_to_is_methods:
                is_name = to_snake_case(node.variant + group.name)
                out.append(f'#[is(name = "{is_name}")]')
            out.append(f"""{node.variant}(&'a {node.ty}),""")
        out.append("}")

        # `From<&OwnedEnum>`: map each owned variant to the matching borrowed variant.
        out.append(f"""
            impl<'a> From<&'a {group.owned_enum_ty}> for {group.ref_enum_ty}<'a> {{
                fn from(node: &'a {group.owned_enum_ty}) -> Self {{
                    match node {{
        """)
        for node in group.nodes:
            out.append(
                f"{group.owned_enum_ty}::{node.variant}(node) => {group.ref_enum_ty}::{node.variant}(node),"
            )
        out.append("""
                    }
                }
            }
        """)

        # `From<&Node>` for each node struct in the group.
        for node in group.nodes:
            out.append(f"""
            impl<'a> From<&'a {node.ty}> for {group.ref_enum_ty}<'a> {{
                fn from(node: &'a {node.ty}) -> Self {{
                    Self::{node.variant}(node)
                }}
            }}
            """)

        # `Ranged` for the ref enum: delegate to the referenced node.
        out.append(f"""
        impl ruff_text_size::Ranged for {group.ref_enum_ty}<'_> {{
            fn range(&self) -> ruff_text_size::TextRange {{
                match self {{
        """)
        for node in group.nodes:
            out.append(f"Self::{node.variant}(node) => node.range(),")
        out.append("""
                }
            }
        }
        """)

        # `HasNodeIndex` for the ref enum: delegate to the referenced node.
        out.append(f"""
        impl crate::HasNodeIndex for {group.ref_enum_ty}<'_> {{
            fn node_index(&self) -> &crate::AtomicNodeIndex {{
                match self {{
        """)
        for node in group.nodes:
            out.append(f"Self::{node.variant}(node) => node.node_index(),")
        out.append("""
                }
            }
        }
        """)
 | |
| 
 | |
| 
 | |
| # ------------------------------------------------------------------------------
 | |
| # AnyNodeRef
 | |
| 
 | |
| 
 | |
def write_anynoderef(out: list[str], ast: Ast) -> None:
    """
    Create the AnyNodeRef type.

    ```rust
    pub enum AnyNodeRef<'a> {
        ...
        TypeParamTypeVar(&'a TypeParamTypeVar),
        TypeParamTypeVarTuple(&'a TypeParamTypeVarTuple),
        ...
    }
    ```

    Also creates:
    - `impl<'a> From<&'a TypeParam> for AnyNodeRef<'a>`
    - `impl<'a> From<TypeParamRef<'a>> for AnyNodeRef<'a>`
    - `impl<'a> From<&'a TypeParamTypeVarTuple> for AnyNodeRef<'a>`
    - `fn AnyNodeRef::as_type_param_ref(self) -> Option<TypeParamRef<'a>>`
    - `impl Ranged for AnyNodeRef<'_>`
    - `impl HasNodeIndex for AnyNodeRef<'_>`
    - `fn AnyNodeRef::as_ptr(&self) -> std::ptr::NonNull<()>`
    - `fn AnyNodeRef::visit_source_order(self, visitor: &mut impl SourceOrderVisitor)`
    - `fn AnyNodeRef::is_type_param(self) -> bool`, where the name after `is_`
      comes from the group's `anynode_is_label`

    The generated Rust source is appended to `out` as a sequence of fragments.
    """

    # The flattened enum: one variant per node (grouped and ungrouped), named
    # after the node itself rather than its in-group variant name.
    out.append("""
    /// A flattened enumeration of all AST nodes.
    #[derive(Copy, Clone, Debug, is_macro::Is, PartialEq)]
    #[cfg_attr(feature = "get-size", derive(get_size2::GetSize))]
    pub enum AnyNodeRef<'a> {
    """)
    for node in ast.all_nodes:
        out.append(f"""{node.name}(&'a {node.ty}),""")
    out.append("""
    }
    """)

    for group in ast.groups:
        # `From<&OwnedEnum>`: map each group variant onto the flattened variant.
        out.append(f"""
            impl<'a> From<&'a {group.owned_enum_ty}> for AnyNodeRef<'a> {{
                fn from(node: &'a {group.owned_enum_ty}) -> AnyNodeRef<'a> {{
                    match node {{
        """)
        for node in group.nodes:
            out.append(
                f"{group.owned_enum_ty}::{node.variant}(node) => AnyNodeRef::{node.name}(node),"
            )
        out.append("""
                    }
                }
            }
        """)

        # `From<RefEnum>`: same mapping, starting from the borrowed group enum.
        out.append(f"""
            impl<'a> From<{group.ref_enum_ty}<'a>> for AnyNodeRef<'a> {{
                fn from(node: {group.ref_enum_ty}<'a>) -> AnyNodeRef<'a> {{
                    match node {{
        """)
        for node in group.nodes:
            out.append(
                f"{group.ref_enum_ty}::{node.variant}(node) => AnyNodeRef::{node.name}(node),"
            )
        out.append("""
                    }
                }
            }
        """)

        # `as_*` methods to convert from `AnyNodeRef` to e.g. `ExprRef`
        # (returns `None` for nodes outside the group).
        out.append(f"""
            impl<'a> AnyNodeRef<'a> {{
                pub fn as_{to_snake_case(group.ref_enum_ty)}(self) -> Option<{group.ref_enum_ty}<'a>> {{
                    match self {{
        """)
        for node in group.nodes:
            out.append(
                f"Self::{node.name}(node) => Some({group.ref_enum_ty}::{node.variant}(node)),"
            )
        out.append("""
                        _ => None,
                    }
                }
            }
        """)

    # `From<&Node>` for every node struct.
    for node in ast.all_nodes:
        out.append(f"""
            impl<'a> From<&'a {node.ty}> for AnyNodeRef<'a> {{
                fn from(node: &'a {node.ty}) -> AnyNodeRef<'a> {{
                    AnyNodeRef::{node.name}(node)
                }}
            }}
        """)

    # `Ranged`: delegate to the referenced node.
    out.append("""
        impl ruff_text_size::Ranged for AnyNodeRef<'_> {
            fn range(&self) -> ruff_text_size::TextRange {
                match self {
    """)
    for node in ast.all_nodes:
        out.append(f"""AnyNodeRef::{node.name}(node) => node.range(),""")
    out.append("""
                }
            }
        }
    """)

    # `HasNodeIndex`: delegate to the referenced node.
    out.append("""
        impl crate::HasNodeIndex for AnyNodeRef<'_> {
            fn node_index(&self) -> &crate::AtomicNodeIndex {
                match self {
    """)
    for node in ast.all_nodes:
        out.append(f"""AnyNodeRef::{node.name}(node) => node.node_index(),""")
    out.append("""
                }
            }
        }
    """)

    # `as_ptr`: a type-erased `NonNull<()>` pointing at the referenced node struct.
    out.append("""
        impl AnyNodeRef<'_> {
            pub fn as_ptr(&self) -> std::ptr::NonNull<()> {
                match self {
    """)
    for node in ast.all_nodes:
        out.append(
            f"AnyNodeRef::{node.name}(node) => std::ptr::NonNull::from(*node).cast(),"
        )
    out.append("""
                }
            }
        }
    """)

    # `visit_source_order`: dispatch to the referenced node's own method.
    out.append("""
        impl<'a> AnyNodeRef<'a> {
            pub fn visit_source_order<'b, V>(self, visitor: &mut V)
            where
                V: crate::visitor::source_order::SourceOrderVisitor<'b> + ?Sized,
                'a: 'b,
            {
                match self {
    """)
    for node in ast.all_nodes:
        out.append(
            f"AnyNodeRef::{node.name}(node) => node.visit_source_order(visitor),"
        )
    out.append("""
                }
            }
        }
    """)

    # `is_<anynode_is_label>`: true for any variant belonging to the group. The
    # variants are joined with `|` inside a single `matches!` pattern.
    for group in ast.groups:
        out.append(f"""
        impl AnyNodeRef<'_> {{
            pub const fn is_{group.anynode_is_label}(self) -> bool {{
                matches!(self,
        """)
        for i, node in enumerate(group.nodes):
            if i > 0:
                out.append("|")
            out.append(f"""AnyNodeRef::{node.name}(_)""")
        out.append("""
                )
            }
        }
        """)
 | |
| 
 | |
| 
 | |
| # ------------------------------------------------------------------------------
 | |
| # AnyRootNodeRef
 | |
| 
 | |
| 
 | |
def write_root_anynoderef(out: list[str], ast: Ast) -> None:
    """
    Create the AnyRootNodeRef type.

    ```rust
    pub enum AnyRootNodeRef<'a> {
        ...
        TypeParam(&'a TypeParam),
        ...
    }
    ```

    Also creates:
    - `impl<'a> From<&'a TypeParam> for AnyRootNodeRef<'a>`
    - `impl<'a> TryFrom<AnyRootNodeRef<'a>> for &'a TypeParam`
    - `impl<'a> TryFrom<AnyRootNodeRef<'a>> for &'a TypeParamTypeVarTuple`
    - `impl Ranged for AnyRootNodeRef<'_>`
    - `impl HasNodeIndex for AnyRootNodeRef<'_>`
    - `fn AnyRootNodeRef::visit_source_order(self, visitor: &mut impl SourceOrderVisitor)`

    The enum has one variant per group (holding a reference to the group's
    owned enum), followed by one variant per ungrouped node (holding a
    reference to the node itself). Every variant gets a `From`/`TryFrom` pair.
    Each node inside a group also gets a `TryFrom` that unwraps through its
    group's variant.
    """
    # (variant name, referenced Rust type) for every root variant, in the order
    # the variants appear in the enum: groups first, then ungrouped nodes.
    variants = [(group.name, group.owned_enum_ty) for group in ast.groups]
    variants += [(node.name, node.ty) for node in ast.ungrouped_nodes]

    out.append("""
    /// An enumeration of all AST nodes.
    ///
    /// Unlike `AnyNodeRef`, this type does not flatten nested enums, so its variants only
    /// consist of the "root" AST node types. This is useful as it exposes references to the
    /// original enums, not just references to their inner values.
    ///
    /// For example, `AnyRootNodeRef::Mod` contains a reference to the `Mod` enum, while
    /// `AnyNodeRef` has top-level `AnyNodeRef::ModModule` and `AnyNodeRef::ModExpression`
    /// variants.
    #[derive(Copy, Clone, Debug, PartialEq)]
    #[cfg_attr(feature = "get-size", derive(get_size2::GetSize))]
    pub enum AnyRootNodeRef<'a> {
    """)
    out.extend(f"""{variant}(&'a {ty}),""" for variant, ty in variants)
    out.append("""
    }
    """)

    # Each group's From/TryFrom pair is immediately followed by the unwrapping
    # `TryFrom`s for the nodes it contains; ungrouped nodes come last.
    for group in ast.groups:
        _write_root_conversions(out, group.name, group.owned_enum_ty)
        for node in group.nodes:
            out.append(f"""
                impl<'a> TryFrom<AnyRootNodeRef<'a>> for &'a {node.ty} {{
                    type Error = ();
                    fn try_from(node: AnyRootNodeRef<'a>) -> Result<&'a {node.ty}, ()> {{
                        match node {{
                            AnyRootNodeRef::{group.name}({group.owned_enum_ty}::{node.variant}(node)) => Ok(node),
                            _ => Err(())
                        }}
                    }}
                }}
            """)
    for node in ast.ungrouped_nodes:
        _write_root_conversions(out, node.name, node.ty)

    _write_root_match_impl(
        out,
        variants,
        """
        impl ruff_text_size::Ranged for AnyRootNodeRef<'_> {
            fn range(&self) -> ruff_text_size::TextRange {
        """,
        "node.range()",
    )
    _write_root_match_impl(
        out,
        variants,
        """
        impl crate::HasNodeIndex for AnyRootNodeRef<'_> {
            fn node_index(&self) -> &crate::AtomicNodeIndex {
        """,
        "node.node_index()",
    )
    _write_root_match_impl(
        out,
        variants,
        """
        impl<'a> AnyRootNodeRef<'a> {
            pub fn visit_source_order<'b, V>(self, visitor: &mut V)
            where
                V: crate::visitor::source_order::SourceOrderVisitor<'b> + ?Sized,
                'a: 'b,
            {
        """,
        "node.visit_source_order(visitor)",
    )


def _write_root_conversions(out: list[str], variant: str, ty: str) -> None:
    """Emit the `From<&ty>` and `TryFrom<AnyRootNodeRef>` impls for one root variant."""
    out.append(f"""
        impl<'a> From<&'a {ty}> for AnyRootNodeRef<'a> {{
            fn from(node: &'a {ty}) -> AnyRootNodeRef<'a> {{
                AnyRootNodeRef::{variant}(node)
            }}
        }}
    """)
    out.append(f"""
        impl<'a> TryFrom<AnyRootNodeRef<'a>> for &'a {ty} {{
            type Error = ();
            fn try_from(node: AnyRootNodeRef<'a>) -> Result<&'a {ty}, ()> {{
                match node {{
                    AnyRootNodeRef::{variant}(node) => Ok(node),
                    _ => Err(())
                }}
            }}
        }}
    """)


def _write_root_match_impl(
    out: list[str], variants: list[tuple[str, str]], header: str, arm: str
) -> None:
    """
    Emit an impl whose single method body is `match self` over every root variant.

    `header` is the Rust text up to and including the method's opening brace.
    Each arm binds the wrapped reference as `node` and evaluates `arm`. The
    method, impl, and match braces are closed here.
    """
    out.append(f"""
        {header}
                match self {{
    """)
    out.extend(f"""AnyRootNodeRef::{variant}(node) => {arm},""" for variant, _ in variants)
    out.append("""
                }
            }
        }
    """)
 | |
| 
 | |
| 
 | |
| # ------------------------------------------------------------------------------
 | |
| # NodeKind
 | |
| 
 | |
| 
 | |
def write_nodekind(out: list[str], ast: Ast) -> None:
    """
    Create the NodeKind type.

    ```rust
    pub enum NodeKind {
        ...
        TypeParamTypeVar,
        TypeParamTypeVarTuple,
        ...
    }
    ```

    Also creates:
    - `fn AnyNodeRef::kind(self) -> NodeKind`

    `NodeKind` gets one fieldless variant per entry in `ast.all_nodes`, and
    `kind` maps each `AnyNodeRef::<Name>(_)` to `NodeKind::<Name>`.
    """
    out.append("""
    #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
    pub enum NodeKind {
    """)
    out.extend(f"""{node.name},""" for node in ast.all_nodes)
    out.append("""
    }
    """)

    out.append("""
    impl AnyNodeRef<'_> {
        pub const fn kind(self) -> NodeKind {
            match self {
    """)
    out.extend(
        f"""AnyNodeRef::{node.name}(_) => NodeKind::{node.name},"""
        for node in ast.all_nodes
    )
    out.append("""
            }
        }
    }
    """)
 | |
| 
 | |
| 
 | |
| # ------------------------------------------------------------------------------
 | |
| # Node structs
 | |
| 
 | |
| 
 | |
def _rust_field_type(ty, group_names: list[str]) -> str:
    """
    Spell out the Rust type of a node field from its parsed type description.

    Wrappers are applied inside-out:
    - a `crate::` prefix for names in `types_requiring_crate_prefix`;
    - `[...]` for slices;
    - `Box<...>` for group types and slices when `ty.seq is False`;
    - finally `Vec<...>` for sequences, or otherwise `Option<...>` for
      optional values.
    """
    rendered = f"{ty.name}"
    if ty.name in types_requiring_crate_prefix:
        rendered = "crate::" + rendered
    if ty.slice_:
        rendered = "[" + rendered + "]"
    boxable = ty.name in group_names or ty.slice_
    if boxable and ty.seq is False:
        rendered = "Box<" + rendered + ">"
    if ty.seq:
        return "Vec<" + rendered + ">"
    if ty.optional:
        return "Option<" + rendered + ">"
    return rendered


def write_node(out: list[str], ast: Ast) -> None:
    """
    Emit a Rust struct definition for every grouped node that declares fields.

    Nodes whose `fields` is None are skipped. Each emitted struct has:
    - its rustdoc (when `doc` is set);
    - the base derives plus any extra `derives`;
    - the `get-size` cfg derive;
    - the `node_index` and `range` members;
    - one `pub` member per declared field.
    Each struct is followed by an empty line.
    """
    group_names = [group.name for group in ast.groups]
    nodes_with_fields = (
        node
        for group in ast.groups
        for node in group.nodes
        if node.fields is not None
    )
    for node in nodes_with_fields:
        if node.doc is not None:
            write_rustdoc(out, node.doc)
        extra_derives = "".join(f", {derive}" for derive in node.derives)
        out.append(f"#[derive(Clone, Debug, PartialEq{extra_derives})]")
        out.append('#[cfg_attr(feature = "get-size", derive(get_size2::GetSize))]')
        out.append(f"pub struct {node.name} {{")
        out.append("pub node_index: crate::AtomicNodeIndex,")
        out.append("pub range: ruff_text_size::TextRange,")
        for field in node.fields:
            rust_ty = _rust_field_type(field.parsed_ty, group_names)
            out.append(f"pub {field.name}: {rust_ty},")
        out.append("}")
        out.append("")
 | |
| 
 | |
| 
 | |
| # ------------------------------------------------------------------------------
 | |
| # Source order visitor
 | |
| 
 | |
| 
 | |
def write_source_order(out: list[str], ast: Ast) -> None:
    """
    Emit a `visit_source_order` method for every grouped node with fields.

    Nodes whose `fields` is None, or that set `custom_source_order`, are
    skipped. The generated method destructures `self`, listing every field plus
    `range` and `node_index`; fields whose `skip_source_order()` is true, and
    `range`/`node_index`, are bound to `_`. It then calls the matching visitor
    method for each field returned by `fields_in_source_order()`.
    """
    for group in ast.groups:
        for node in group.nodes:
            if node.fields is None or node.custom_source_order:
                continue

            name = node.name
            # The pattern names every field (no `..`), so a node field the
            # generator does not know about makes the generated code fail to
            # compile instead of being silently ignored.
            fields_list = "".join(
                f"{field.name}: _,\n" if field.skip_source_order() else f"{field.name},\n"
                for field in node.fields
            )
            fields_list += "range: _,\nnode_index: _,\n"

            # Query the source-order fields once and reuse them for both the
            # body and the parameter-name decision below.
            source_fields = list(node.fields_in_source_order())
            body = "".join(_source_order_visit(field) for field in source_fields)

            # With nothing to visit the parameter is unused, so name it `_`.
            visitor_arg_name = "visitor" if source_fields else "_"

            out.append(f"""
impl {name} {{
    pub(crate) fn visit_source_order<'a, V>(&'a self, {visitor_arg_name}: &mut V)
    where
        V: SourceOrderVisitor<'a> + ?Sized,
    {{
        let {name} {{
            {fields_list}
        }} = self;
        {body}
    }}
}}
        """)


def _source_order_visit(field) -> str:
    """
    Return the Rust statement(s) that visit `field` inside `visit_source_order`.

    The visitor method comes from `type_to_visitor_function`; when that has no
    name, it falls back to `visit_<snake_case type>`. Annotation fields always
    use `visit_annotation`. Optional fields are visited only when `Some`.
    Sequence fields are visited element-wise unless the registered visitor
    accepts the whole sequence.
    """
    # Single lookup; the original code looked up (and built the fallback) twice.
    info = type_to_visitor_function.get(field.parsed_ty.inner, VisitorInfo(""))
    visitor_name = info.name or f"visit_{to_snake_case(field.parsed_ty.inner)}"
    if field.is_annotation:
        visitor_name = "visit_annotation"

    if field.parsed_ty.optional:
        # The trailing `\n` keeps a blank line after the `if let` block in the
        # generated file.
        return f"""
                            if let Some({field.name}) = {field.name} {{
                                visitor.{visitor_name}({field.name});
                            }}\n
                      """
    if not info.accepts_sequence and field.parsed_ty.seq:
        return f"""
                            for elm in {field.name} {{
                                visitor.{visitor_name}(elm);
                            }}
                     """
    return f"visitor.{visitor_name}({field.name});\n"
 | |
| 
 | |
| 
 | |
| # ------------------------------------------------------------------------------
 | |
| # Format and write output
 | |
| 
 | |
| 
 | |
def generate(ast: Ast) -> list[str]:
    """
    Build the generated Rust module as a list of source fragments.

    `write_preamble` runs first. Then each section writer is called with the
    shared fragment list and `ast`, in the order listed below.
    """
    lines: list[str] = []
    write_preamble(lines)
    section_writers = (
        write_owned_enum,
        write_ref_enum,
        write_anynoderef,
        write_root_anynoderef,
        write_nodekind,
        write_node,
        write_source_order,
    )
    for write_section in section_writers:
        write_section(lines, ast)
    return lines
 | |
| 
 | |
| 
 | |
def write_output(root: Path, out: list[str]) -> None:
    """
    Format the generated fragments with rustfmt and write the result to
    `crates/ruff_python_ast/src/generated.rs` under `root`.

    Fragments are joined with newlines before formatting. The file is written
    as UTF-8 explicitly: Rust source must be UTF-8, while `Path.write_text`
    without `encoding` uses the locale's preferred encoding. On some platforms
    (e.g. Windows with a legacy code page) that would mangle or reject any
    non-ASCII text, such as text in emitted doc comments.
    """
    out_path = root.joinpath("crates", "ruff_python_ast", "src", "generated.rs")
    out_path.write_text(rustfmt("\n".join(out)), encoding="utf-8")
 | |
| 
 | |
| 
 | |
| # ------------------------------------------------------------------------------
 | |
| # Main
 | |
| 
 | |
| 
 | |
def main() -> None:
    """
    Regenerate the AST code for the git repository containing the cwd.

    The repository root is taken from `git rev-parse --show-toplevel`. The AST
    description is loaded from that root, rendered with `generate`, and handed
    to `write_output`.
    """
    toplevel = check_output(["git", "rev-parse", "--show-toplevel"], text=True)
    root = Path(toplevel.strip())
    write_output(root, generate(load_ast(root)))


if __name__ == "__main__":
    main()
 |