# spell-checker:words dfn dfns

# ! /usr/bin/env python
"""Generate Rust code from an ASDL description."""

import sys
import json
import textwrap
import re

from argparse import ArgumentParser
from pathlib import Path
from typing import Optional, Dict, Any

import asdl

# Number of spaces emitted per indentation level of generated code.
TABSIZE = 4
AUTO_GEN_MESSAGE = "// File automatically generated by {}.\n\n"

# Rust type names used for the ASDL builtin types.
BUILTIN_TYPE_NAMES = {
    "identifier": "Identifier",
    "string": "String",
    "int": "Int",
    "constant": "Constant",
}
# Fail at import time if the ASDL builtin set and this table ever diverge.
assert BUILTIN_TYPE_NAMES.keys() == asdl.builtin_types

# Per-field overrides for ASDL `int` fields, keyed by field name.
BUILTIN_INT_NAMES = {
    "simple": "bool",
    "is_async": "bool",
    "conversion": "ConversionFlag",
}

# ASDL names that need an explicit word break before being converted.
RENAME_MAP = {
    "cmpop": "cmp_op",
    "unaryop": "unary_op",
    "boolop": "bool_op",
    "excepthandler": "except_handler",
    "withitem": "with_item",
}

# Identifiers that must be escaped when used as Rust field/method names.
RUST_KEYWORDS = {
    "if",
    "while",
    "for",
    "return",
    "match",
    "try",
    "await",
    "yield",
    "in",
    "mod",
    "type",
}

attributes = [
    asdl.Field("int", "lineno"),
    asdl.Field("int", "col_offset"),
    asdl.Field("int", "end_lineno"),
    asdl.Field("int", "end_col_offset"),
]

ORIGINAL_NODE_WARNING = "NOTE: This type is different from original Python AST."

arg_with_default = asdl.Type(
    "arg_with_default",
    asdl.Product(
        [
            asdl.Field("arg", "def"),
            asdl.Field(
                "expr", "default", opt=True
            ),  # order is important for cost-free borrow!
        ],
    ),
)
# The text below is emitted verbatim as a Rust doc comment.
arg_with_default.doc = (
    "An alternative type of AST `arg`. "
    "This is used for each function argument that might have a default value. "
    "Used by `Arguments` original type. "
    f"{ORIGINAL_NODE_WARNING}"
)

alt_arguments = asdl.Type(
    "alt:arguments",
    asdl.Product(
        [
            asdl.Field("arg_with_default", "posonlyargs", seq=True),
            asdl.Field("arg_with_default", "args", seq=True),
            asdl.Field("arg", "vararg", opt=True),
            asdl.Field("arg_with_default", "kwonlyargs", seq=True),
            asdl.Field("arg", "kwarg", opt=True),
        ]
    ),
)
# The text below is emitted verbatim as a Rust doc comment.
alt_arguments.doc = (
    "An alternative type of AST `arguments`. "
    "This is parser-friendly and human-friendly definition of function arguments. "
    "This form also has advantage to implement pre-order traverse. "
    "`defaults` and `kw_defaults` fields are removed and the default values are "
    "placed under each `arg_with_default` typed argument. "
    "`vararg` and `kwarg` are still typed as `arg` because they never can have "
    "a default value. "
    "The matching Python style AST type is [PythonArguments]. "
    "While [PythonArguments] has ordered `kwonlyargs` fields by default existence, "
    "[Arguments] has location-ordered kwonlyargs fields. "
    f"{ORIGINAL_NODE_WARNING}"
)

# Must be used only for rust types, not python types
CUSTOM_TYPES = [
    alt_arguments,
    arg_with_default,
]

# ASDL type name -> hand-written type that replaces it on the Rust side.
CUSTOM_REPLACEMENTS = {
    "arguments": alt_arguments,
}

CUSTOM_ATTACHMENTS = [
    arg_with_default,
]


def maybe_custom(type):
    """Return the custom replacement registered for `type`, or `type` itself."""
    return CUSTOM_REPLACEMENTS.get(type.name, type)


def rust_field_name(name):
    """Return the snake_case Rust identifier for an ASDL name.

    The name is first converted with `rust_type_name` and then split before
    every interior upper-case letter.
    """
    name = rust_type_name(name)
    # NOTE(review): everything after `r"(?` was lost from the file; this is the
    # usual CamelCase -> snake_case pattern, confirm against version control.
    return re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()


# NOTE(review): `rust_type_name`, `is_simple`, and the head of `TypeInfo`
# (annotations, `__init__`, `__repr__`) were missing from the file. They are
# reconstructed from how the rest of the module uses them; confirm against
# version control. Other definitions may have been lost in the same span.
def rust_type_name(name):
    """Return the Rust type name for an ASDL type or constructor name.

    Builtin ASDL types map through `BUILTIN_TYPE_NAMES`; all-lowercase names
    are converted from snake_case to CamelCase; anything else is returned
    unchanged.
    """
    name = RENAME_MAP.get(name, name)
    if name in asdl.builtin_types:
        return BUILTIN_TYPE_NAMES[name]
    if name.islower():
        return "".join(part.capitalize() for part in name.split("_"))
    return name


def is_simple(sum):
    """Return True if none of the constructors of `sum` has any field."""
    return not any(t.fields for t in sum.types)


class TypeInfo:
    """Facts collected about one ASDL type or constructor."""

    type: asdl.Type
    enum_name: Optional[str]
    has_user_data: Optional[bool]
    has_attributes: bool
    is_simple: bool
    children: set
    fields: Optional[Any]
    boxed: bool

    def __init__(self, type):
        self.type = type
        # Name of the owning sum type when this describes a constructor.
        self.enum_name = None
        # None means "not decided yet"; see `determine_user_data`.
        self.has_user_data = None
        self.has_attributes = False
        self.is_simple = False
        # Set of (field type name, is sequence) pairs.
        self.children = set()
        self.fields = None
        self.boxed = False

    def __repr__(self):
        return f"<TypeInfo: {self.name}>"

    @property
    def name(self):
        """ASDL name of the wrapped node."""
        return self.type.name

    @property
    def is_type(self):
        """True when the wrapped node is an `asdl.Type`."""
        return isinstance(self.type, asdl.Type)

    @property
    def is_product(self):
        """True when the wrapped node is a type whose value is a product."""
        return self.is_type and isinstance(self.type.value, asdl.Product)

    @property
    def is_sum(self):
        """True when the wrapped node is a type whose value is a sum."""
        return self.is_type and isinstance(self.type.value, asdl.Sum)

    @property
    def has_expr(self):
        """True for a product with at least one non-`identifier` field."""
        return self.is_product and any(
            f.type != "identifier" for f in self.type.value.fields
        )

    @property
    def is_custom(self):
        """True when the name matches one of `CUSTOM_TYPES`."""
        own_name = self.type.name
        return any(t.name == own_name for t in CUSTOM_TYPES)

    @property
    def is_custom_replaced(self):
        """True when a replacement is registered in `CUSTOM_REPLACEMENTS`."""
        return self.type.name in CUSTOM_REPLACEMENTS

    @property
    def custom(self):
        """The registered replacement type, or the wrapped node itself."""
        return maybe_custom(self.type)

    def no_cfg(self, typeinfo):
        """Return the `has_attributes` flag that governs this node.

        Products and plain types use their own flag; a constructor
        (`enum_name` set) uses the flag of its owning sum type, looked up in
        `typeinfo`.
        """
        if self.is_product:
            return self.has_attributes
        if self.enum_name:
            return typeinfo[self.enum_name].has_attributes
        return self.has_attributes

    @property
    def rust_name(self):
        """Rust type name derived from the ASDL name alone."""
        return rust_type_name(self.name)

    @property
    def full_field_name(self):
        """snake_case name, prefixed with the owning enum for constructors."""
        name = self.name
        if name.startswith("alt:"):
            name = name[4:]
        if self.enum_name is None:
            return name
        return f"{self.enum_name}_{rust_field_name(name)}"

    @property
    def full_type_name(self):
        """Rust type name including enum prefix and `Python` replacement prefix."""
        name = self.name
        if name.startswith("alt:"):
            name = name[4:]
        rust_name = rust_type_name(name)
        if self.enum_name is not None:
            rust_name = rust_type_name(self.enum_name) + rust_name
        if self.is_custom_replaced:
            # The original ASDL type keeps existing under a `Python` prefix.
            rust_name = "Python" + rust_name
        return rust_name

    def determine_user_data(self, type_info, stack):
        """Propagate `has_user_data` upwards from child types.

        `stack` holds the names currently being visited; re-entering one of
        them returns None so recursive type definitions terminate.
        """
        if self.name in stack:
            return None
        stack.add(self.name)
        for child, _child_seq in self.children:
            if child in asdl.builtin_types:
                continue
            child_info = type_info[child]
            child_has_user_data = child_info.determine_user_data(type_info, stack)
            # Only an undecided node can be switched on by a child.
            if self.has_user_data is None and child_has_user_data is True:
                self.has_user_data = True
        stack.remove(self.name)
        return self.has_user_data


class TypeInfoMixin:
    """Lookup helpers for classes that carry a `type_info` mapping."""

    type_info: Dict[str, TypeInfo]

    def customized_type_info(self, type_name):
        """Return the info of the type that replaces `type_name`, if any."""
        info = self.type_info[type_name]
        return self.type_info[info.custom.name]

    def has_user_data(self, typ):
        """Return the recorded `has_user_data` flag of `typ`."""
        return self.type_info[typ].has_user_data

    def apply_generics(self, typ, *generics):
        """Return one generic suffix per entry of `generics`.

        Each suffix is `<name>` unless `typ` is simple, in which case every
        suffix is the empty string.
        """
        if self.type_info[typ].is_simple:
            return [""] * len(generics)
        return [f"<{g}>" for g in generics]


class EmitVisitor(asdl.VisitorBase, TypeInfoMixin):
    """Visit that emits lines"""

    def __init__(self, file, type_info):
        self.file = file
        self.type_info = type_info
        self.identifiers = set()
        super().__init__()

    def emit_identifier(self, name):
        """Emit a `_Py_IDENTIFIER` line for `name` once per visitor."""
        name = str(name)
        if name in self.identifiers:
            return
        self.emit("_Py_IDENTIFIER(%s);" % name, 0)
        self.identifiers.add(name)

    def emit(self, line, depth):
        """Write `line` to the output file followed by a newline.

        A non-empty line is dedented and prefixed with `depth` levels of
        indentation; an empty line is written as a bare newline.
        """
        if line:
            line = (" " * TABSIZE * depth) + textwrap.dedent(line)
        self.file.write(line + "\n")


class FindUserDataTypesVisitor(asdl.VisitorBase):
    """Fill `type_info` with a `TypeInfo` for every visited definition."""

    def __init__(self, type_info):
        self.type_info = type_info
        super().__init__()

    def visitModule(self, mod):
        for dfn in mod.dfns + CUSTOM_TYPES:
            self.visit(dfn)
        # One shared cycle guard for the whole propagation pass.
        stack = set()
        for info in self.type_info.values():
            info.determine_user_data(self.type_info, stack)

    def visitType(self, type):
        key = type.name
        info = self.type_info[key] = TypeInfo(type)
        # NOTE(review): the rest of this method continues on the next source line.
self.visit(type.value, info) def visitSum(self, sum, info): type = info.type info.is_simple = is_simple(sum) for cons in sum.types: self.visit(cons, type, info.is_simple) if info.is_simple: info.has_user_data = False return for t in sum.types: self.add_children(t.name, t.fields) if len(sum.types) > 1: info.boxed = True if sum.attributes: # attributes means located, which has the `range: R` field info.has_user_data = True info.has_attributes = True for variant in sum.types: self.add_children(type.name, variant.fields) def visitConstructor(self, cons, type, simple): info = self.type_info[cons.name] = TypeInfo(cons) info.enum_name = type.name info.is_simple = simple def visitProduct(self, product, info): type = info.type if product.attributes: # attributes means located, which has the `range: R` field info.has_user_data = True info.has_attributes = True if len(product.fields) > 2: info.boxed = True self.add_children(type.name, product.fields) def add_children(self, name, fields): self.type_info[name].children.update( (field.type, field.seq) for field in fields ) def rust_field(field_name): if field_name in RUST_KEYWORDS: field_name += "_" return field_name class StructVisitor(EmitVisitor): """Visitor to generate type-defs for AST.""" def __init__(self, *args, **kw): super().__init__(*args, **kw) def emit_attrs(self, depth): self.emit("#[derive(Clone, Debug, PartialEq)]", depth) def emit_range(self, has_attributes, depth): if has_attributes: self.emit("pub range: R,", depth + 1) else: self.emit("pub range: OptionalRange,", depth + 1) def visitModule(self, mod): self.emit_attrs(0) self.emit( """ #[derive(is_macro::Is)] pub enum Ast { """, 0, ) for dfn in mod.dfns: info = self.customized_type_info(dfn.name) dfn = info.custom rust_name = info.full_type_name generics = "" if self.type_info[dfn.name].is_simple else "" if dfn.name == "mod": # This is exceptional rule to other enums. 
# Unlike other enums, this is justified because `Mod` is only used as # the top node of parsing result and never a child node of other nodes. # Because it will be very rarely used in very particular applications, # "ast_" prefix to everywhere seems less useful. self.emit('#[is(name = "module")]', 1) self.emit(f"{rust_name}({rust_name}{generics}),", 1) self.emit( """ } impl Node for Ast { const NAME: &'static str = "AST"; const FIELD_NAMES: &'static [&'static str] = &[]; } """, 0, ) for dfn in mod.dfns: info = self.customized_type_info(dfn.name) rust_name = info.full_type_name generics = "" if self.type_info[dfn.name].is_simple else "" self.emit( f""" impl From<{rust_name}{generics}> for Ast {{ fn from(node: {rust_name}{generics}) -> Self {{ Ast::{rust_name}(node) }} }} """, 0, ) for dfn in mod.dfns + CUSTOM_TYPES: self.visit(dfn) def visitType(self, type, depth=0): if hasattr(type, "doc"): doc = "/// " + type.doc.replace("\n", "\n/// ") + "\n" else: doc = f"/// See also [{type.name}](https://docs.python.org/3/library/ast.html#ast.{type.name})" self.emit(doc, depth) self.visit(type.value, type, depth) def visitSum(self, sum, type, depth): if is_simple(sum): self.simple_sum(sum, type, depth) else: self.sum_with_constructors(sum, type, depth) (generics_applied,) = self.apply_generics(type.name, "R") self.emit( f""" impl{generics_applied} Node for {rust_type_name(type.name)}{generics_applied} {{ const NAME: &'static str = "{type.name}"; const FIELD_NAMES: &'static [&'static str] = &[]; }} """, depth, ) def simple_sum(self, sum, type, depth): rust_name = rust_type_name(type.name) self.emit_attrs(depth) self.emit("#[derive(is_macro::Is, Copy, Hash, Eq)]", depth) self.emit(f"pub enum {rust_name} {{", depth) for cons in sum.types: self.emit(f"{cons.name},", depth + 1) self.emit("}", depth) self.emit(f"impl {rust_name} {{", depth) needs_escape = any(rust_field_name(t.name) in RUST_KEYWORDS for t in sum.types) if needs_escape: prefix = rust_field_name(type.name) + "_" else: 
prefix = "" for cons in sum.types: self.emit( f""" #[inline] pub const fn {prefix}{rust_field_name(cons.name)}(&self) -> Option<{rust_name}{cons.name}> {{ match self {{ {rust_name}::{cons.name} => Some({rust_name}{cons.name}), _ => None, }} }} """, depth, ) self.emit("}", depth) self.emit("", depth) for cons in sum.types: self.emit( f""" pub struct {rust_name}{cons.name}; impl From<{rust_name}{cons.name}> for {rust_name} {{ fn from(_: {rust_name}{cons.name}) -> Self {{ {rust_name}::{cons.name} }} }} impl From<{rust_name}{cons.name}> for Ast {{ fn from(_: {rust_name}{cons.name}) -> Self {{ {rust_name}::{cons.name}.into() }} }} impl Node for {rust_name}{cons.name} {{ const NAME: &'static str = "{cons.name}"; const FIELD_NAMES: &'static [&'static str] = &[]; }} impl std::cmp::PartialEq<{rust_name}> for {rust_name}{cons.name} {{ #[inline] fn eq(&self, other: &{rust_name}) -> bool {{ matches!(other, {rust_name}::{cons.name}) }} }} """, 0, ) def sum_with_constructors(self, sum, type, depth): type_info = self.type_info[type.name] rust_name = rust_type_name(type.name) self.emit_attrs(depth) self.emit("#[derive(is_macro::Is)]", depth) self.emit(f"pub enum {rust_name} {{", depth) needs_escape = any(rust_field_name(t.name) in RUST_KEYWORDS for t in sum.types) for t in sum.types: if needs_escape: self.emit( f'#[is(name = "{rust_field_name(t.name)}_{rust_name.lower()}")]', depth + 1, ) self.emit(f"{t.name}({rust_name}{t.name}),", depth + 1) self.emit("}", depth) self.emit("", depth) for t in sum.types: self.sum_subtype_struct(type_info, t, rust_name, depth) def sum_subtype_struct(self, sum_type_info, t, rust_name, depth): self.emit(f"""/// See also [{t.name}](https://docs.python.org/3/library/ast.html#ast.{t.name})""", depth) self.emit_attrs(depth) payload_name = f"{rust_name}{t.name}" self.emit(f"pub struct {payload_name} {{", depth) self.emit_range(sum_type_info.has_attributes, depth) for f in t.fields: self.visit(f, sum_type_info, "pub ", depth + 1, t.name) assert 
sum_type_info.has_attributes == self.type_info[t.name].no_cfg( self.type_info ) self.emit("}", depth) field_names = [f'"{f.name}"' for f in t.fields] self.emit( f""" impl Node for {payload_name} {{ const NAME: &'static str = "{t.name}"; const FIELD_NAMES: &'static [&'static str] = &[{', '.join(field_names)}]; }} impl From<{payload_name}> for {rust_name} {{ fn from(payload: {payload_name}) -> Self {{ {rust_name}::{t.name}(payload) }} }} impl From<{payload_name}> for Ast {{ fn from(payload: {payload_name}) -> Self {{ {rust_name}::from(payload).into() }} }} """, depth, ) self.emit("", depth) def visitConstructor(self, cons, parent, depth): if cons.fields: self.emit(f"{cons.name} {{", depth) for f in cons.fields: self.visit(f, parent, "", depth + 1, cons.name) self.emit("},", depth) else: self.emit(f"{cons.name},", depth) def visitField(self, field, parent, vis, depth, constructor=None): try: field_type = self.customized_type_info(field.type) typ = field_type.full_type_name except KeyError: field_type = None typ = rust_type_name(field.type) if field_type and not field_type.is_simple: typ = f"{typ}" # don't box if we're doing Vec, but do box if we're doing Vec>> if ( field_type and field_type.boxed and (not (parent.is_product or field.seq) or field.opt) ): typ = f"Box<{typ}>" if field.opt or ( # When a dictionary literal contains dictionary unpacking (e.g., `{**d}`), # the expression to be unpacked goes in `values` with a `None` at the corresponding # position in `keys`. To handle this, the type of `keys` needs to be `Option>`. 
constructor == "Dict" and field.name == "keys" ): typ = f"Option<{typ}>" if field.seq: typ = f"Vec<{typ}>" if typ == "Int": typ = BUILTIN_INT_NAMES.get(field.name, typ) name = rust_field(field.name) self.emit(f"{vis}{name}: {typ},", depth) def visitProduct(self, product, type, depth): type_info = self.type_info[type.name] product_name = type_info.full_type_name self.emit_attrs(depth) self.emit(f"pub struct {product_name} {{", depth) self.emit_range(product.attributes, depth + 1) for f in product.fields: self.visit(f, type_info, "pub ", depth + 1) assert bool(product.attributes) == type_info.no_cfg(self.type_info) self.emit("}", depth) field_names = [f'"{f.name}"' for f in product.fields] self.emit( f""" impl Node for {product_name} {{ const NAME: &'static str = "{type.name}"; const FIELD_NAMES: &'static [&'static str] = &[ {', '.join(field_names)} ]; }} """, depth, ) class FoldTraitDefVisitor(EmitVisitor): def visitModule(self, mod, depth): self.emit("pub trait Fold {", depth) self.emit("type TargetU;", depth + 1) self.emit("type Error;", depth + 1) self.emit("type UserContext;", depth + 1) self.emit( """ fn will_map_user(&mut self, user: &U) -> Self::UserContext; #[cfg(feature = "all-nodes-with-ranges")] fn will_map_user_cfg(&mut self, user: &U) -> Self::UserContext { self.will_map_user(user) } #[cfg(not(feature = "all-nodes-with-ranges"))] fn will_map_user_cfg(&mut self, _user: &crate::EmptyRange) -> crate::EmptyRange { crate::EmptyRange::default() } fn map_user(&mut self, user: U, context: Self::UserContext) -> Result; #[cfg(feature = "all-nodes-with-ranges")] fn map_user_cfg(&mut self, user: U, context: Self::UserContext) -> Result { self.map_user(user, context) } #[cfg(not(feature = "all-nodes-with-ranges"))] fn map_user_cfg( &mut self, _user: crate::EmptyRange, _context: crate::EmptyRange, ) -> Result, Self::Error> { Ok(crate::EmptyRange::default()) } """, depth + 1, ) self.emit( """ fn fold>(&mut self, node: X) -> Result { node.fold(self) }""", depth + 1, ) 
for dfn in mod.dfns + [arg_with_default]: dfn = maybe_custom(dfn) self.visit(dfn, depth + 2) self.emit("}", depth) def visitType(self, type, depth): info = self.type_info[type.name] apply_u, apply_target_u = self.apply_generics(info.name, "U", "Self::TargetU") enum_name = info.full_type_name self.emit( f"fn fold_{info.full_field_name}(&mut self, node: {enum_name}{apply_u}) -> Result<{enum_name}{apply_target_u}, Self::Error> {{", depth, ) self.emit(f"fold_{info.full_field_name}(self, node)", depth + 1) self.emit("}", depth) if isinstance(type.value, asdl.Sum) and not is_simple(type.value): for cons in type.value.types: self.visit(cons, type, depth) def visitConstructor(self, cons, type, depth): info = self.type_info[type.name] apply_u, apply_target_u = self.apply_generics(type.name, "U", "Self::TargetU") enum_name = rust_type_name(type.name) func_name = f"fold_{info.full_field_name}_{rust_field_name(cons.name)}" self.emit( f"fn {func_name}(&mut self, node: {enum_name}{cons.name}{apply_u}) -> Result<{enum_name}{cons.name}{apply_target_u}, Self::Error> {{", depth, ) self.emit(f"{func_name}(self, node)", depth + 1) self.emit("}", depth) class FoldImplVisitor(EmitVisitor): def visitModule(self, mod, depth): for dfn in mod.dfns + [arg_with_default]: dfn = maybe_custom(dfn) self.visit(dfn, depth) def visitType(self, type, depth=0): self.visit(type.value, type, depth) def visitSum(self, sum, type, depth): name = type.name apply_t, apply_u, apply_target_u = self.apply_generics( name, "T", "U", "F::TargetU" ) enum_name = rust_type_name(name) simple = is_simple(sum) self.emit(f"impl Foldable for {enum_name}{apply_t} {{", depth) self.emit(f"type Mapped = {enum_name}{apply_u};", depth + 1) self.emit( "fn fold + ?Sized>(self, folder: &mut F) -> Result {", depth + 1, ) self.emit(f"folder.fold_{name}(self)", depth + 2) self.emit("}", depth + 1) self.emit("}", depth) self.emit( f"pub fn fold_{name} + ?Sized>(#[allow(unused)] folder: &mut F, node: {enum_name}{apply_u}) -> 
Result<{enum_name}{apply_target_u}, F::Error> {{", depth, ) if simple: self.emit("Ok(node) }", depth + 1) return self.emit("let folded = match node {", depth + 1) for cons in sum.types: self.emit( f"{enum_name}::{cons.name}(cons) => {enum_name}::{cons.name}(Foldable::fold(cons, folder)?),", depth + 1, ) self.emit("};", depth + 1) self.emit("Ok(folded)", depth + 1) self.emit("}", depth) for cons in sum.types: self.visit(cons, type, depth) def visitConstructor(self, cons, type, depth): apply_t, apply_u, apply_target_u = self.apply_generics( type.name, "T", "U", "F::TargetU" ) info = self.type_info[type.name] enum_name = info.full_type_name cons_type_name = f"{enum_name}{cons.name}" self.emit(f"impl Foldable for {cons_type_name}{apply_t} {{", depth) self.emit(f"type Mapped = {cons_type_name}{apply_u};", depth + 1) self.emit( "fn fold + ?Sized>(self, folder: &mut F) -> Result {", depth + 1, ) self.emit( f"folder.fold_{info.full_field_name}_{rust_field_name(cons.name)}(self)", depth + 2, ) self.emit("}", depth + 1) self.emit("}", depth) self.emit( f"pub fn fold_{info.full_field_name}_{rust_field_name(cons.name)} + ?Sized>(#[allow(unused)] folder: &mut F, node: {cons_type_name}{apply_u}) -> Result<{enum_name}{cons.name}{apply_target_u}, F::Error> {{", depth, ) fields_pattern = self.make_pattern(cons.fields) map_user_suffix = "" if info.has_attributes else "_cfg" self.emit( f""" let {cons_type_name} {{ {fields_pattern} }} = node; let context = folder.will_map_user{map_user_suffix}(&range); """, depth + 3, ) self.fold_fields(cons.fields, depth + 3) self.emit( f"let range = folder.map_user{map_user_suffix}(range, context)?;", depth + 3, ) self.composite_fields(f"{cons_type_name}", cons.fields, depth + 3) self.emit("}", depth + 2) def visitProduct(self, product, type, depth): info = self.type_info[type.name] name = type.name apply_t, apply_u, apply_target_u = self.apply_generics( name, "T", "U", "F::TargetU" ) struct_name = info.full_type_name has_attributes = 
bool(product.attributes) self.emit(f"impl Foldable for {struct_name}{apply_t} {{", depth) self.emit(f"type Mapped = {struct_name}{apply_u};", depth + 1) self.emit( "fn fold + ?Sized>(self, folder: &mut F) -> Result {", depth + 1, ) self.emit(f"folder.fold_{info.full_field_name}(self)", depth + 2) self.emit("}", depth + 1) self.emit("}", depth) self.emit( f"pub fn fold_{info.full_field_name} + ?Sized>(#[allow(unused)] folder: &mut F, node: {struct_name}{apply_u}) -> Result<{struct_name}{apply_target_u}, F::Error> {{", depth, ) fields_pattern = self.make_pattern(product.fields) self.emit(f"let {struct_name} {{ {fields_pattern} }} = node;", depth + 1) map_user_suffix = "" if has_attributes else "_cfg" self.emit( f"let context = folder.will_map_user{map_user_suffix}(&range);", depth + 3 ) self.fold_fields(product.fields, depth + 1) self.emit( f"let range = folder.map_user{map_user_suffix}(range, context)?;", depth + 3 ) self.composite_fields(struct_name, product.fields, depth + 1) self.emit("}", depth) def make_pattern(self, fields): body = ",".join(rust_field(f.name) for f in fields) if body: body += "," body += "range" return body def fold_fields(self, fields, depth): for field in fields: name = rust_field(field.name) self.emit(f"let {name} = Foldable::fold({name}, folder)?;", depth + 1) def composite_fields(self, header, fields, depth): self.emit(f"Ok({header} {{", depth) for field in fields: name = rust_field(field.name) self.emit(f"{name},", depth + 1) self.emit("range,", depth + 1) self.emit("})", depth) class FoldModuleVisitor(EmitVisitor): def visitModule(self, mod): depth = 0 FoldTraitDefVisitor(self.file, self.type_info).visit(mod, depth) FoldImplVisitor(self.file, self.type_info).visit(mod, depth) class VisitorModuleVisitor(StructVisitor): def full_name(self, name): type_info = self.type_info[name] if type_info.enum_name: return f"{type_info.enum_name}_{name}" else: return name def node_type_name(self, name): type_info = self.type_info[name] if 
type_info.enum_name: return f"{rust_type_name(type_info.enum_name)}{rust_type_name(name)}" else: return rust_type_name(name) def visitModule(self, mod, depth=0): self.emit("#[allow(unused_variables)]", depth) self.emit("pub trait Visitor {", depth) for dfn in mod.dfns: dfn = self.customized_type_info(dfn.name).type self.visit(dfn, depth + 1) self.emit("}", depth) def visitType(self, type, depth=0): self.visit(type.value, type.name, depth) def visitSum(self, sum, name, depth): if is_simple(sum): self.simple_sum(sum, name, depth) else: self.sum_with_constructors(sum, name, depth) def emit_visitor(self, nodename, depth, has_node=True): type_info = self.type_info[nodename] node_type = type_info.full_type_name (generic,) = self.apply_generics(nodename, "R") self.emit( f"fn visit_{type_info.full_field_name}(&mut self, node: {node_type}{generic}) {{", depth, ) if has_node: self.emit( f"self.generic_visit_{type_info.full_field_name}(node)", depth + 1 ) self.emit("}", depth) def emit_generic_visitor_signature(self, nodename, depth, has_node=True): type_info = self.type_info[nodename] if has_node: node_type = type_info.full_type_name else: node_type = "()" (generic,) = self.apply_generics(nodename, "R") self.emit( f"fn generic_visit_{type_info.full_field_name}(&mut self, node: {node_type}{generic}) {{", depth, ) def emit_empty_generic_visitor(self, nodename, depth): self.emit_generic_visitor_signature(nodename, depth) self.emit("}", depth) def simple_sum(self, sum, name, depth): self.emit_visitor(name, depth) self.emit_empty_generic_visitor(name, depth) def visit_match_for_type(self, nodename, rust_name, type_, depth): self.emit(f"{rust_name}::{type_.name}", depth) self.emit("(data)", depth) self.emit( f"=> self.visit_{nodename}_{rust_field_name(type_.name)}(data),", depth ) def visit_sum_type(self, name, type_, depth): self.emit_visitor(type_.name, depth, has_node=type_.fields) if not type_.fields: return self.emit_generic_visitor_signature(type_.name, depth, has_node=True) 
for field in type_.fields: if field.type in CUSTOM_REPLACEMENTS: type_name = CUSTOM_REPLACEMENTS[field.type].name else: type_name = field.type field_name = rust_field(field.name) field_type = self.type_info.get(type_name) if not (field_type and field_type.has_user_data): continue if field.opt: self.emit(f"if let Some(value) = node.{field_name} {{", depth + 1) elif field.seq: iterable = f"node.{field_name}" if type_.name == "Dict" and field.name == "keys": iterable = f"{iterable}.into_iter().flatten()" self.emit(f"for value in {iterable} {{", depth + 1) else: self.emit("{", depth + 1) self.emit(f"let value = node.{field_name};", depth + 2) variable = "value" if field_type.boxed and (not field.seq or field.opt): variable = "*" + variable type_info = self.type_info[field_type.name] self.emit(f"self.visit_{type_info.full_field_name}({variable});", depth + 2) self.emit("}", depth + 1) self.emit("}", depth) def sum_with_constructors(self, sum, name, depth): if not sum.attributes: return enum_name = rust_type_name(name) self.emit_visitor(name, depth) self.emit_generic_visitor_signature(name, depth) depth += 1 self.emit("match node {", depth) for t in sum.types: self.visit_match_for_type(name, enum_name, t, depth + 1) self.emit("}", depth) depth -= 1 self.emit("}", depth) # Now for the visitors for the types for t in sum.types: self.visit_sum_type(name, t, depth) def visitProduct(self, product, name, depth): self.emit_visitor(name, depth) self.emit_empty_generic_visitor(name, depth) class RangedDefVisitor(EmitVisitor): def visitModule(self, mod): for dfn in mod.dfns + CUSTOM_TYPES: self.visit(dfn) def visitType(self, type, depth=0): self.visit(type.value, type.name, depth) def visitSum(self, sum, name, depth): info = self.type_info[name] self.emit_type_alias(info) if info.is_simple: for ty in sum.types: variant_info = self.type_info[ty.name] self.emit_type_alias(variant_info) return sum_match_arms = "" for ty in sum.types: variant_info = self.type_info[ty.name] 
sum_match_arms += ( f" Self::{variant_info.rust_name}(node) => node.range()," ) self.emit_type_alias(variant_info) self.emit_ranged_impl(variant_info) if not info.no_cfg(self.type_info): self.emit('#[cfg(feature = "all-nodes-with-ranges")]', 0) self.emit( f""" impl Ranged for crate::{info.full_type_name} {{ fn range(&self) -> TextRange {{ match self {{ {sum_match_arms} }} }} }} """.lstrip(), 0, ) def visitProduct(self, product, name, depth): info = self.type_info[name] self.emit_type_alias(info) self.emit_ranged_impl(info) def emit_type_alias(self, info): return # disable generics = "" if info.is_simple else "::" self.emit( f"pub type {info.full_type_name} = crate::generic::{info.full_type_name}{generics};", 0, ) self.emit("", 0) def emit_ranged_impl(self, info): if not info.no_cfg(self.type_info): self.emit('#[cfg(feature = "all-nodes-with-ranges")]', 0) self.file.write( f""" impl Ranged for crate::generic::{info.full_type_name}:: {{ fn range(&self) -> TextRange {{ self.range }} }} """.strip() ) class LocatedDefVisitor(EmitVisitor): def visitModule(self, mod): for dfn in mod.dfns + CUSTOM_TYPES: self.visit(dfn) def visitType(self, type, depth=0): self.visit(type.value, type.name, depth) def visitSum(self, sum, name, depth): info = self.type_info[name] self.emit_type_alias(info) if info.is_simple: for ty in sum.types: variant_info = self.type_info[ty.name] self.emit_type_alias(variant_info) return sum_match_arms = "" for ty in sum.types: variant_info = self.type_info[ty.name] sum_match_arms += ( f" Self::{variant_info.rust_name}(node) => node.range()," ) self.emit_type_alias(variant_info) self.emit_located_impl(variant_info) if not info.no_cfg(self.type_info): cfg = '#[cfg(feature = "all-nodes-with-ranges")]' else: cfg = '' self.emit( f""" {cfg} impl Located for {info.full_type_name} {{ fn range(&self) -> SourceRange {{ match self {{ {sum_match_arms} }} }} }} {cfg} impl LocatedMut for {info.full_type_name} {{ fn range_mut(&mut self) -> &mut SourceRange {{ match 
self {{ {sum_match_arms.replace('range()', 'range_mut()')} }} }} }} """.lstrip(), 0, ) def visitProduct(self, product, name, depth): info = self.type_info[name] self.emit_type_alias(info) self.emit_located_impl(info) def emit_type_alias(self, info): generics = "" if info.is_simple else "::" self.emit( f"pub type {info.full_type_name} = crate::generic::{info.full_type_name}{generics};", 0, ) self.emit("", 0) def emit_located_impl(self, info): if not info.no_cfg(self.type_info): cfg = '#[cfg(feature = "all-nodes-with-ranges")]' else: cfg = '' self.emit( f""" {cfg} impl Located for {info.full_type_name} {{ fn range(&self) -> SourceRange {{ self.range }} }} {cfg} impl LocatedMut for {info.full_type_name} {{ fn range_mut(&mut self) -> &mut SourceRange {{ &mut self.range }} }} """, 0, ) class ToPyo3AstVisitor(EmitVisitor): """Visitor to generate type-defs for AST.""" def __init__(self, namespace, *args, **kw): super().__init__(*args, **kw) self.namespace = namespace @property def generics(self): if self.namespace == "ranged": return "" elif self.namespace == "located": return "" else: assert False, self.namespace def visitModule(self, mod): for dfn in mod.dfns: self.visit(dfn) def visitType(self, type): self.visit(type.value, type) def visitProduct(self, product, type): info = self.type_info[type.name] rust_name = info.full_type_name self.emit_to_pyo3_with_fields(product, type, rust_name) def visitSum(self, sum, type): info = self.type_info[type.name] rust_name = info.full_type_name simple = is_simple(sum) if is_simple(sum): return self.emit( f""" impl ToPyAst for ast::{rust_name}{self.generics} {{ #[inline] fn to_py_ast<'py>(&self, {"_" if simple else ""}py: Python<'py>) -> PyResult<&'py PyAny> {{ let instance = match &self {{ """, 0, ) for cons in sum.types: self.emit( f"ast::{rust_name}::{cons.name}(cons) => cons.to_py_ast(py)?,", 1, ) self.emit( """ }; Ok(instance) } } """, 0, ) for cons in sum.types: self.visit(cons, type) def visitConstructor(self, cons, type): 
parent = rust_type_name(type.name) self.emit_to_pyo3_with_fields(cons, type, f"{parent}{cons.name}") def emit_to_pyo3_with_fields(self, cons, type, name): type_info = self.type_info[type.name] self.emit( f""" impl ToPyAst for ast::{name}{self.generics} {{ #[inline] fn to_py_ast<'py>(&self, py: Python<'py>) -> PyResult<&'py PyAny> {{ let cache = Self::py_type_cache().get().unwrap(); """, 0, ) if cons.fields: field_names = ", ".join(rust_field(f.name) for f in cons.fields) if not type_info.is_simple: field_names += ", range: _range" self.emit( f"let Self {{ {field_names} }} = self;", 1, ) self.emit( """ let instance = Py::::as_ref(&cache.0, py).call1(( """, 1, ) for field in cons.fields: if field.type == "constant": self.emit( f"constant_to_object({rust_field(field.name)}, py),", 3, ) continue if field.type == "int": if field.name == "level": assert field.opt self.emit( f"{rust_field(field.name)}.map_or_else(|| py.None(), |level| level.to_u32().to_object(py)),", 3, ) continue if field.name == "lineno": self.emit( f"{rust_field(field.name)}.to_u32().to_object(py),", 3, ) continue self.emit( f"{rust_field(field.name)}.to_py_ast(py)?,", 3, ) self.emit( "))?;", 0, ) else: self.emit( "let Self { range: _range } = self;", 1, ) self.emit( """let instance = Py::::as_ref(&cache.0, py).call0()?;""", 1, ) if type.value.attributes and self.namespace == "located": self.emit( """ let cache = ast_cache(); instance.setattr(cache.lineno.as_ref(py), _range.start.row.get())?; instance.setattr(cache.col_offset.as_ref(py), _range.start.column.get())?; if let Some(end) = _range.end { instance.setattr(cache.end_lineno.as_ref(py), end.row.get())?; instance.setattr(cache.end_col_offset.as_ref(py), end.column.get())?; } """, 0, ) self.emit( """ Ok(instance) } } """, 0, ) class Pyo3StructVisitor(EmitVisitor): """Visitor to generate type-defs for AST.""" def __init__(self, namespace, *args, **kw): self.namespace = namespace self.borrow = True super().__init__(*args, **kw) @property def 
generics(self): if self.namespace == "ranged": return "" elif self.namespace == "located": return "" else: assert False, self.namespace @property def module_name(self): name = f"rustpython_ast.{self.namespace}" return name @property def ref_def(self): return "&'static " if self.borrow else "" @property def ref(self): return "&" if self.borrow else "" def emit_class(self, info, simple, base="super::Ast"): inner_name = info.full_type_name rust_name = self.type_info[info.custom.name].full_type_name if simple: generics = "" else: generics = self.generics if info.is_sum: subclass = ", subclass" body = "" into = f"{rust_name}" else: subclass = "" body = f"(pub {self.ref_def} ast::{inner_name}{generics})" into = f"{rust_name}(node)" self.emit( f""" #[pyclass(module="{self.module_name}", name="_{info.name}", extends={base}, frozen{subclass})] #[derive(Clone, Debug)] pub struct {rust_name} {body}; impl From<{self.ref_def} ast::{inner_name}{generics}> for {rust_name} {{ fn from({"" if body else "_"}node: {self.ref_def} ast::{inner_name}{generics}) -> Self {{ {into} }} }} """, 0, ) if subclass: self.emit( f""" #[pymethods] impl {rust_name} {{ #[new] fn new() -> PyClassInitializer {{ PyClassInitializer::from(Ast) .add_subclass(Self) }} }} impl ToPyObject for {rust_name} {{ fn to_object(&self, py: Python) -> PyObject {{ let initializer = Self::new(); Py::new(py, initializer).unwrap().into_py(py) }} }} """, 0, ) else: if base != "super::Ast": add_subclass = f".add_subclass({base})" else: add_subclass = "" self.emit( f""" impl ToPyObject for {rust_name} {{ fn to_object(&self, py: Python) -> PyObject {{ let initializer = PyClassInitializer::from(Ast) {add_subclass} .add_subclass(self.clone()); Py::new(py, initializer).unwrap().into_py(py) }} }} """, 0, ) if not subclass: self.emit_wrapper(info) def emit_getter(self, owner, type_name): self.emit( f""" #[pymethods] impl {type_name} {{ """, 0, ) for field in owner.fields: self.emit( f""" #[getter] #[inline] fn get_{field.name}(&self, 
py: Python) -> PyResult {{ self.0.{rust_field(field.name)}.to_py_wrapper(py) }} """, 3, ) self.emit( """ } """, 0, ) def emit_getattr(self, owner, type_name): self.emit( f""" #[pymethods] impl {type_name} {{ fn __getattr__(&self, py: Python, key: &str) -> PyResult {{ let object: Py = match key {{ """, 0, ) for field in owner.fields: self.emit( f'"{field.name}" => self.0.{rust_field(field.name)}.to_py_wrapper(py)?,', 3, ) self.emit( """ _ => todo!(), }; Ok(object) } } """, 0, ) def emit_wrapper(self, info): inner_name = info.full_type_name rust_name = self.type_info[info.custom.name].full_type_name self.emit( f""" impl ToPyWrapper for ast::{inner_name}{self.generics} {{ #[inline] fn to_py_wrapper(&'static self, py: Python) -> PyResult> {{ Ok({rust_name}(self).to_object(py)) }} }} """, 0, ) def visitModule(self, mod): for dfn in mod.dfns: self.visit(dfn) def visitType(self, type, depth=0): self.visit(type.value, type, depth) def visitSum(self, sum, type, depth=0): info = self.type_info[type.name] rust_name = rust_type_name(type.name) simple = is_simple(sum) self.emit_class(info, simple) if not simple: self.emit( f""" impl ToPyWrapper for ast::{rust_name}{self.generics} {{ #[inline] fn to_py_wrapper(&'static self, py: Python) -> PyResult> {{ match &self {{ """, 0, ) for cons in sum.types: self.emit(f"Self::{cons.name}(cons) => cons.to_py_wrapper(py),", 3) self.emit( """ } } } """, 0, ) for cons in sum.types: self.visit(cons, rust_name, simple, depth + 1) def visitProduct(self, product, type, depth=0): info = self.type_info[type.name] rust_name = rust_type_name(type.name) self.emit_class(info, False) if self.borrow: self.emit_getter(product, rust_name) def visitConstructor(self, cons, parent, simple, depth): if simple: self.emit( f""" #[pyclass(module="{self.module_name}", name="_{cons.name}", extends={parent})] pub struct {parent}{cons.name}; impl ToPyObject for {parent}{cons.name} {{ fn to_object(&self, py: Python) -> PyObject {{ let initializer = 
PyClassInitializer::from(Ast) .add_subclass({parent}) .add_subclass(Self); Py::new(py, initializer).unwrap().into_py(py) }} }} """, depth, ) else: info = self.type_info[cons.name] self.emit_class( info, simple=False, base=parent, ) if self.borrow: self.emit_getter(cons, f"{parent}{cons.name}") class Pyo3PymoduleVisitor(EmitVisitor): def __init__(self, namespace, *args, **kw): self.namespace = namespace super().__init__(*args, **kw) def visitModule(self, mod): for dfn in mod.dfns: self.visit(dfn) def visitType(self, type, depth=0): self.visit(type.value, type.name, depth) def visitProduct(self, product, name, depth=0): info = self.type_info[name] self.emit_fields(info, False) def visitSum(self, sum, name, depth): info = self.type_info[name] simple = is_simple(sum) self.emit_fields(info, True) for cons in sum.types: self.visit(cons, name, simple, depth) def visitConstructor(self, cons, parent, simple, depth): info = self.type_info[cons.name] self.emit_fields(info, simple) def emit_fields(self, info, simple): inner_name = info.full_type_name rust_name = self.type_info[info.custom.name].full_type_name self.emit(f"super::init_type::<{rust_name}, ast::{inner_name}>(py, m)?;", 1) class StdlibClassDefVisitor(EmitVisitor): def visitModule(self, mod): for dfn in mod.dfns: self.visit(dfn) def visitType(self, type, depth=0): self.visit(type.value, type.name, depth) def visitSum(self, sum, name, depth): # info = self.type_info[self.type_info[name].custom.name] info = self.type_info[name] struct_name = "Node" + info.full_type_name self.emit( f'#[pyclass(module = "_ast", name = {json.dumps(name)}, base = "NodeAst")]', depth, ) self.emit(f"struct {struct_name};", depth) self.emit("#[pyclass(flags(HAS_DICT, BASETYPE))]", depth) self.emit(f"impl {struct_name} {{}}", depth) for cons in sum.types: self.visit(cons, sum.attributes, struct_name, depth) def visitConstructor(self, cons, attrs, base, depth): self.gen_class_def(cons.name, cons.fields, attrs, depth, base) def 
visitProduct(self, product, name, depth): self.gen_class_def(name, product.fields, product.attributes, depth) def gen_class_def(self, name, fields, attrs, depth, base=None): info = self.type_info[self.type_info[name].custom.name] if base is None: base = "NodeAst" struct_name = "Node" + info.full_type_name else: struct_name = "Node" + info.full_type_name self.emit( f'#[pyclass(module = "_ast", name = {json.dumps(name)}, base = {json.dumps(base)})]', depth, ) self.emit(f"struct {struct_name};", depth) self.emit("#[pyclass(flags(HAS_DICT, BASETYPE))]", depth) self.emit(f"impl {struct_name} {{", depth) self.emit("#[extend_class]", depth + 1) self.emit( "fn extend_class_with_fields(ctx: &Context, class: &'static Py) {", depth + 1, ) fields = ",".join( f"ctx.new_str(ascii!({json.dumps(f.name)})).into()" for f in fields ) self.emit( f"class.set_attr(identifier!(ctx, _fields), ctx.new_tuple(vec![{fields}]).into());", depth + 2, ) attrs = ",".join( f"ctx.new_str(ascii!({json.dumps(attr.name)})).into()" for attr in attrs ) self.emit( f"class.set_attr(identifier!(ctx, _attributes), ctx.new_list(vec![{attrs}]).into());", depth + 2, ) self.emit("}", depth + 1) self.emit("}", depth) class StdlibExtendModuleVisitor(EmitVisitor): def visitModule(self, mod): depth = 0 self.emit( "pub fn extend_module_nodes(vm: &VirtualMachine, module: &Py) {", depth, ) self.emit("extend_module!(vm, module, {", depth + 1) for dfn in mod.dfns: self.visit(dfn, depth + 2) self.emit("})", depth + 1) self.emit("}", depth) def visitType(self, type, depth): self.visit(type.value, type.name, depth) def visitSum(self, sum, name, depth): rust_name = rust_type_name(name) self.emit(f"{json.dumps(name)} => Node{rust_name}::make_class(&vm.ctx),", depth) for cons in sum.types: self.visit(cons, depth, rust_name) def visitConstructor(self, cons, depth, rust_name): self.gen_extension(cons.name, depth, rust_name) def visitProduct(self, product, name, depth): self.gen_extension(name, depth) def gen_extension(self, 
name, depth, base=""): rust_name = rust_type_name(name) self.emit( f"{json.dumps(name)} => Node{base}{rust_name}::make_class(&vm.ctx),", depth ) class StdlibTraitImplVisitor(EmitVisitor): def visitModule(self, mod): for dfn in mod.dfns: self.visit(dfn) def visitType(self, type, depth=0): self.visit(type.value, type.name, depth) def visitSum(self, sum, name, depth): info = self.type_info[name] rust_name = info.full_type_name self.emit("// sum", depth) self.emit(f"impl Node for ast::located::{rust_name} {{", depth) self.emit( "fn ast_to_object(self, vm: &VirtualMachine) -> PyObjectRef {", depth + 1 ) simple = is_simple(sum) if simple: self.emit("let node_type = match self {", depth + 2) for cons in sum.types: self.emit( f"ast::located::{rust_name}::{cons.name} => Node{rust_name}{cons.name}::static_type(),", depth, ) self.emit("};", depth + 3) self.emit( "NodeAst.into_ref_with_type(vm, node_type.to_owned()).unwrap().into()", depth + 2, ) else: self.emit("match self {", depth + 2) for cons in sum.types: self.emit( f"ast::located::{rust_name}::{cons.name}(cons) => cons.ast_to_object(vm),", depth + 3, ) self.emit("}", depth + 2) self.emit("}", depth + 1) self.emit( "fn ast_from_object(_vm: &VirtualMachine, _object: PyObjectRef) -> PyResult {", depth + 1, ) self.gen_sum_from_object(sum, name, rust_name, depth + 2) self.emit("}", depth + 1) self.emit("}", depth) if not is_simple(sum): for cons in sum.types: self.visit(cons, sum, rust_name, depth) def visitConstructor(self, cons, sum, sum_rust_name, depth): rust_name = rust_type_name(cons.name) self.emit("// constructor", depth) self.emit(f"impl Node for ast::located::{sum_rust_name}{rust_name} {{", depth) fields_pattern = self.make_pattern(cons.fields) self.emit( "fn ast_to_object(self, _vm: &VirtualMachine) -> PyObjectRef {", depth + 1 ) self.emit( f"let ast::located::{sum_rust_name}{rust_name} {{ {fields_pattern} }} = self;", depth, ) self.make_node(cons.name, sum, cons.fields, depth + 2, sum_rust_name) self.emit("}", 
depth + 1) self.emit( "fn ast_from_object(_vm: &VirtualMachine, _object: PyObjectRef) -> PyResult {", depth + 1, ) self.gen_product_from_object( cons, cons.name, f"{sum_rust_name}{rust_name}", sum.attributes, depth + 2 ) self.emit("}", depth + 1) self.emit("}", depth + 1) def visitProduct(self, product, name, depth): info = self.type_info[name] struct_name = info.full_type_name self.emit("// product", depth) self.emit(f"impl Node for ast::located::{struct_name} {{", depth) self.emit( "fn ast_to_object(self, _vm: &VirtualMachine) -> PyObjectRef {", depth + 1 ) fields_pattern = self.make_pattern(product.fields) self.emit( f"let ast::located::{struct_name} {{ {fields_pattern} }} = self;", depth + 2, ) self.make_node(name, product, product.fields, depth + 2) self.emit("}", depth + 1) self.emit( "fn ast_from_object(_vm: &VirtualMachine, _object: PyObjectRef) -> PyResult {", depth + 1, ) self.gen_product_from_object( product, name, struct_name, product.attributes, depth + 2 ) self.emit("}", depth + 1) self.emit("}", depth) def make_node(self, variant, owner, fields, depth, base=""): rust_variant = rust_type_name(variant) self.emit( f"let node = NodeAst.into_ref_with_type(_vm, Node{base}{rust_variant}::static_type().to_owned()).unwrap();", depth, ) if fields or owner.attributes: self.emit("let dict = node.as_object().dict().unwrap();", depth) for f in fields: self.emit( f"dict.set_item({json.dumps(f.name)}, {rust_field(f.name)}.ast_to_object(_vm), _vm).unwrap();", depth, ) if owner.attributes: self.emit("node_add_location(&dict, _range, _vm);", depth) self.emit("node.into()", depth) def make_pattern(self, fields): return "".join(f"{rust_field(f.name)}," for f in fields) + "range: _range" def gen_sum_from_object(self, sum, sum_name, rust_name, depth): # if sum.attributes: # self.extract_location(sum_name, depth) self.emit("let _cls = _object.class();", depth) self.emit("Ok(", depth) for cons in sum.types: self.emit( f"if _cls.is(Node{rust_name}{cons.name}::static_type()) 
{{", depth ) self.emit(f"ast::located::{rust_name}::{cons.name}", depth + 1) if not is_simple(sum): self.emit( f"(ast::located::{rust_name}{cons.name}::ast_from_object(_vm, _object)?)", depth + 1, ) self.emit("} else", depth) self.emit("{", depth) msg = f'format!("expected some sort of {sum_name}, but got {{}}",_object.repr(_vm)?)' self.emit(f"return Err(_vm.new_type_error({msg}));", depth + 1) self.emit("})", depth) def gen_product_from_object( self, product, product_name, struct_name, has_attributes, depth ): self.emit("Ok(", depth) self.gen_construction( struct_name, product, product_name, has_attributes, depth + 1 ) self.emit(")", depth) def gen_construction_fields(self, cons, name, depth): for field in cons.fields: self.emit( f"{rust_field(field.name)}: {self.decode_field(field, name)},", depth + 1, ) def gen_construction(self, cons_path, cons, name, attributes, depth): self.emit(f"ast::located::{cons_path} {{", depth) self.gen_construction_fields(cons, name, depth + 1) if attributes: self.emit(f'range: range_from_object(_vm, _object, "{name}")?,', depth + 1) else: self.emit("range: Default::default(),", depth + 1) self.emit("}", depth) def extract_location(self, typename, depth): row = self.decode_field(asdl.Field("int", "lineno"), typename) column = self.decode_field(asdl.Field("int", "col_offset"), typename) self.emit( f""" let _location = {{ let row = {row}; let column = {column}; try_location(row, column) }}; """, depth, ) def decode_field(self, field, typename): name = json.dumps(field.name) if field.opt and not field.seq: return f"get_node_field_opt(_vm, &_object, {name})?.map(|obj| Node::ast_from_object(_vm, obj)).transpose()?" else: return f"Node::ast_from_object(_vm, get_node_field(_vm, &_object, {name}, {json.dumps(typename)})?)?" 
class ChainOfVisitors:
    """Run several emitting visitors over the same object, one after another."""

    def __init__(self, *visitors):
        """Remember *visitors* in the order they were given."""
        self.visitors = visitors

    def visit(self, object):
        """Visit *object* with every visitor, emitting an empty line after each."""
        for v in self.visitors:
            v.visit(object)
            # An empty emit separates the output of consecutive visitors.
            v.emit("", 0)


def write_ast_def(mod, type_info, f):
    """Write the `TextRange` import, then run `StructVisitor` over *mod*."""
    f.write("use crate::text_size::TextRange;")
    StructVisitor(f, type_info).visit(mod)


def write_fold_def(mod, type_info, f):
    """Run `FoldModuleVisitor` over *mod*, writing to *f*."""
    FoldModuleVisitor(f, type_info).visit(mod)


def write_visitor_def(mod, type_info, f):
    """Run `VisitorModuleVisitor` over *mod*, writing to *f*."""
    VisitorModuleVisitor(f, type_info).visit(mod)


def write_ranged_def(mod, type_info, f):
    """Run `RangedDefVisitor` over *mod*, writing to *f*."""
    RangedDefVisitor(f, type_info).visit(mod)


def write_located_def(mod, type_info, f):
    """Run `LocatedDefVisitor` over *mod*, writing to *f*."""
    LocatedDefVisitor(f, type_info).visit(mod)


def write_pyo3_node(type_info, f):
    """Write one `impl PyNode for ast::<name>` block per entry of *type_info*.

    Custom entries are skipped when their key equals the name of the type they
    hold; a custom entry stored under a different key is written with a
    ``Python`` prefix on its Rust name.
    """

    def write(info: "TypeInfo", rust_name: str):
        # NOTE(review): both branches yield the empty string, so the
        # `is_simple` test currently has no effect -- presumably the
        # non-simple branch once carried a generic parameter list; confirm.
        if info.is_simple:
            generics = ""
        else:
            generics = ""
        f.write(
            f" impl{generics} PyNode for ast::{rust_name}{generics} {{"
            f" #[inline]"
            f" fn py_type_cache() -> &'static OnceCell<(Py, Py)> {{"
            f" static PY_TYPE: OnceCell<(Py, Py)> = OnceCell::new();"
            f" &PY_TYPE }} }} ",
        )

    for type_name, info in type_info.items():
        rust_name = info.full_type_name
        if info.is_custom:
            if type_name == info.type.name:
                continue
            rust_name = "Python" + rust_name
        write(info, rust_name)


def write_to_pyo3(mod, type_info, f):
    """Write the `to_py_ast` support code for *mod* to *f*.

    Output order: the `PyNode` impls, the `ToPyAst` impls for simple sums, the
    `ToPyo3AstVisitor` output for the "ranged" and then the "located"
    namespace, and finally an `init_types` function with one `cache_py_type`
    line per non-custom entry of *type_info*.
    """
    write_pyo3_node(type_info, f)
    write_to_pyo3_simple(type_info, f)

    for namespace in ("ranged", "located"):
        ToPyo3AstVisitor(namespace, f, type_info).visit(mod)

    f.write(
        " fn init_types(py: Python) -> PyResult<()> {"
        ' let ast_module = PyModule::import(py, "_ast")?; '
    )

    for info in type_info.values():
        if info.is_custom:
            continue
        # NOTE(review): `rust_name` is computed but never interpolated, and the
        # emitted `cache_py_type::(...)` has an empty turbofish -- presumably a
        # type argument built from `rust_name` was lost; confirm before
        # relying on this output.
        rust_name = info.full_type_name
        f.write(f"cache_py_type::(ast_module)?;\n")
    f.write("Ok(())\n}")


def write_to_pyo3_simple(type_info, f):
    """Write an `impl ToPyAst` block for every simple sum type in *type_info*.

    Entries that are not sums, or are sums but not simple, produce no output.
    Each written impl contains one match arm per constructor of the sum.
    """
    # The loop variable is deliberately not called `type_info`: rebinding the
    # parameter inside its own iteration made the body easy to misread.
    for info in type_info.values():
        if not (info.is_sum and info.is_simple):
            continue
        rust_name = info.full_type_name
        f.write(
            f" impl ToPyAst for ast::{rust_name} {{"
            f" #[inline]"
            f" fn to_py_ast<'py>(&self, py: Python<'py>) -> PyResult<&'py PyAny> {{"
            f" let cell = match &self {{ ",
        )
        for cons in info.type.value.types:
            f.write(
                f"ast::{rust_name}::{cons.name} => ast::{rust_name}{cons.name}::py_type_cache(),",
            )
        f.write(
            " }; Ok(Py::::as_ref(&cell.get().unwrap().1, py)) } } ",
        )


def write_pyo3_wrapper(mod, type_info, namespace, f):
    """Write the pyo3 wrapper module for *namespace* to *f*.

    Runs `Pyo3StructVisitor` first. For the "located" namespace only, it then
    writes `ToPyWrapper` impls for every simple sum type and for each of its
    constructors. Finally writes an `add_to_module` function whose body is
    produced by `Pyo3PymoduleVisitor`.
    """
    Pyo3StructVisitor(namespace, f, type_info).visit(mod)

    if namespace == "located":
        for info in type_info.values():
            if not info.is_simple or not info.is_sum:
                continue
            rust_name = info.full_type_name
            inner_name = type_info[info.custom.name].full_type_name
            constructors = info.type.value.types
            f.write(
                f" impl ToPyWrapper for ast::{inner_name} {{"
                f" #[inline]"
                f" fn to_py_wrapper(&self, py: Python) -> PyResult> {{"
                f" match &self {{ ",
            )
            for cons in constructors:
                f.write(
                    f"Self::{cons.name} => Ok({rust_name}{cons.name}.to_object(py)),",
                )
            f.write(
                " } } } ",
            )
            for cons in constructors:
                f.write(
                    f" impl ToPyWrapper for ast::{rust_name}{cons.name} {{"
                    f" #[inline]"
                    f" fn to_py_wrapper(&self, py: Python) -> PyResult> {{"
                    f" Ok({rust_name}{cons.name}.to_object(py)) }} }} "
                )

    f.write(
        " pub fn add_to_module(py: Python, m: &PyModule) -> PyResult<()> {"
        " super::init_module(py, m)?; "
    )

    Pyo3PymoduleVisitor(namespace, f, type_info).visit(mod)
    f.write("Ok(())\n}")


def write_parse_def(mod, type_info, f):
    """Write an `impl Parse` block for every constructor of `expr` and `stmt`.

    Only entries whose `enum_name` is "expr" or "stmt" are written. The
    generated `parse_tokens` delegates to the parent enum's `parse_tokens` and
    returns a `ParseError` when the parsed node is a different variant.
    *mod* is accepted for signature parity with the other writers; it is not
    read here.
    """
    for info in type_info.values():
        if info.enum_name not in ["expr", "stmt"]:
            continue

        type_name = rust_type_name(info.enum_name)
        cons_name = rust_type_name(info.name)

        f.write(
            f" impl Parse for ast::{info.full_type_name} {{"
            f" fn lex_starts_at( source: &str, offset: TextSize, )"
            f" -> SoftKeywordTransformer> {{"
            f" ast::{type_name}::lex_starts_at(source, offset) }}"
            f" fn parse_tokens( lxr: impl IntoIterator, source_path: &str, )"
            f" -> Result {{"
            f" let node = ast::{type_name}::parse_tokens(lxr, source_path)?;"
            f" match node {{"
            f" ast::{type_name}::{cons_name}(node) => Ok(node),"
            f" node => Err(ParseError {{"
            f" error: ParseErrorType::InvalidToken,"
            f" offset: node.range().start(),"
            f" source_path: source_path.to_owned(), }}),"
            f" }} }} }} "
        )


def write_ast_mod(mod, type_info, f):
    """Write the stdlib `_ast` module source for *mod* to *f*.

    Writes a fixed header, then chains `StdlibClassDefVisitor`,
    `StdlibTraitImplVisitor` and `StdlibExtendModuleVisitor` over *mod*.
    """
    f.write(
        "\n#![allow(clippy::all)] use super::*; use crate::common::ascii; "
    )

    c = ChainOfVisitors(
        StdlibClassDefVisitor(f, type_info),
        StdlibTraitImplVisitor(f, type_info),
        StdlibExtendModuleVisitor(f, type_info),
    )
    c.visit(mod)


def main(
    input_filename,
    ast_dir,
    parser_dir,
    ast_pyo3_dir,
    module_filename,
    dump_module=False,
):
    """Parse the ASDL file and regenerate every Rust output file.

    Args:
        input_filename: path of the ASDL description to parse.
        ast_dir: directory receiving generic/fold/ranged/located/visitor `.rs`.
        parser_dir: directory receiving `parse.rs`.
        ast_pyo3_dir: directory receiving the pyo3 `.rs` files.
        module_filename: path of the stdlib `_ast` module file to write.
        dump_module: when true, print the parsed module before generating.

    Exits the process with status 1 when `asdl.check` rejects the module.
    """
    from functools import partial

    auto_gen_msg = AUTO_GEN_MESSAGE.format("/".join(Path(__file__).parts[-2:]))
    mod = asdl.parse(input_filename)
    if dump_module:
        print("Parsed Module:")
        print(mod)
    if not asdl.check(mod):
        sys.exit(1)

    type_info = {}
    FindUserDataTypesVisitor(type_info).visit(mod)

    # (target directory, file stem, writer taking the open file) in the order
    # the files are regenerated.
    outputs = [
        (ast_dir, "generic", partial(write_ast_def, mod, type_info)),
        (ast_dir, "fold", partial(write_fold_def, mod, type_info)),
        (ast_dir, "ranged", partial(write_ranged_def, mod, type_info)),
        (ast_dir, "located", partial(write_located_def, mod, type_info)),
        (ast_dir, "visitor", partial(write_visitor_def, mod, type_info)),
        (parser_dir, "parse", partial(write_parse_def, mod, type_info)),
        (ast_pyo3_dir, "to_py_ast", partial(write_to_pyo3, mod, type_info)),
        (
            ast_pyo3_dir,
            "wrapper_located",
            partial(write_pyo3_wrapper, mod, type_info, "located"),
        ),
        (
            ast_pyo3_dir,
            "wrapper_ranged",
            partial(write_pyo3_wrapper, mod, type_info, "ranged"),
        ),
    ]
    for directory, stem, write in outputs:
        # Rust source files must be UTF-8, so do not depend on the locale.
        with (directory / f"{stem}.rs").open("w", encoding="utf-8") as f:
            f.write(auto_gen_msg)
            write(f)

    with module_filename.open("w", encoding="utf-8") as module_file:
        module_file.write(auto_gen_msg)
        write_ast_mod(mod, type_info, module_file)

    print(f"{ast_dir}, {module_filename} regenerated.")


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("input_file", type=Path)
    parser.add_argument("-A", "--ast-dir", type=Path, required=True)
    parser.add_argument("-P", "--parser-dir", type=Path, required=True)
    parser.add_argument("-O", "--ast-pyo3-dir", type=Path, required=True)
    parser.add_argument("-M", "--module-file", type=Path, required=True)
    parser.add_argument("-d", "--dump-module", action="store_true")

    args = parser.parse_args()
    main(
        args.input_file,
        args.ast_dir,
        args.parser_dir,
        args.ast_pyo3_dir,
        args.module_file,
        args.dump_module,
    )