RustPython-Parser/ast/asdl_rs.py
2023-05-09 20:34:48 +09:00

898 lines
30 KiB
Python
Executable file

#! /usr/bin/env python
"""Generate Rust code from an ASDL description."""
import sys
import json
import textwrap
from argparse import ArgumentParser
from pathlib import Path
from typing import Optional, Dict
from attr import dataclass
import asdl
# Number of spaces written per indentation level by EmitVisitor.emit.
TABSIZE = 4
# First line written to every generated file; `{}` is filled in by main()
# with the trailing two components of this script's path.
AUTOGEN_MESSAGE = "// File automatically generated by {}.\n"
# Rust type name emitted for each ASDL builtin type name.
builtin_type_mapping = {
"identifier": "Ident",
"string": "String",
"int": "usize",
"constant": "Constant",
}
# Import-time sanity check: the mapping must cover exactly the builtin type
# names that the asdl module declares.
assert builtin_type_mapping.keys() == asdl.builtin_types
def get_rust_type(name):
    """Return the Rust type name to emit for the ASDL type *name*.

    ASDL builtin types are translated through ``builtin_type_mapping``.
    Any other name for which ``str.islower()`` is true is split on
    underscores and each part is capitalized (``type_ignore`` becomes
    ``TypeIgnore``).  Every other name is returned unchanged.
    """
    if name in asdl.builtin_types:
        return builtin_type_mapping[name]
    if not name.islower():
        return name
    pieces = [piece.capitalize() for piece in name.split("_")]
    return "".join(pieces)
def is_simple(sum):
    """Return True if *sum* is a simple sum.

    A sum is simple when none of its constructors carries fields, e.g.

        unaryop = Invert | Not | UAdd | USub
    """
    return not any(variant.fields for variant in sum.types)
def asdl_of(name, obj):
    """Render *obj* as ASDL-style text labelled with *name*.

    Products and constructors are rendered as ``name(field, ...)`` (or just
    ``name`` when the joined field text is empty).  Anything else is treated
    as a sum and rendered as ``name = A | B``; a non-simple sum puts each
    constructor on its own line.
    """
    if isinstance(obj, (asdl.Product, asdl.Constructor)):
        rendered = ", ".join(str(field) for field in obj.fields)
        suffix = "({})".format(rendered) if rendered else rendered
        return "{}{}".format(name, suffix)

    if is_simple(obj):
        alternatives = " | ".join(variant.name for variant in obj.types)
    else:
        # Continuation lines are padded so each "|" lines up under the "=".
        joiner = "\n{}| ".format(" " * (len(name) + 1))
        alternatives = joiner.join(
            asdl_of(variant.name, variant) for variant in obj.types
        )
    return "{} = {}".format(name, alternatives)
class TypeInfo:
    """Per-type bookkeeping collected before any Rust code is emitted.

    One instance is created for each ASDL type definition and for each sum
    constructor that has fields.
    """

    # ASDL name of the type or constructor.
    name: str
    # Name of the owning sum when this record describes a constructor.
    enum_name: Optional[str]
    # True/False once decided; None while still undetermined.
    has_userdata: Optional[bool]
    # Whether the ASDL definition declares attributes.
    has_attributes: bool
    # Set of (field type name, field.seq) pairs referenced by this type.
    children: set
    boxed: bool
    # True when the record describes a product type.
    product: bool
    # For products with attributes: whether any field is not an identifier.
    has_expr: bool = False

    def __init__(self, name):
        """Create a record for *name* with every flag in its default state."""
        self.name = name
        self.enum_name = None
        self.has_userdata = None
        self.has_attributes = False
        self.children = set()
        self.boxed = False
        self.product = False
        # `has_expr` is the attribute the rest of the generator reads and
        # writes; give every instance its own value instead of relying on the
        # class-level default.
        self.has_expr = False
        # Older spelling, kept so code that still looks it up does not break.
        self.product_has_expr = False

    def __repr__(self):
        return f"<TypeInfo: {self.name}>"

    @property
    def rust_name(self):
        """Rust type name derived from ``name`` alone."""
        return get_rust_type(self.name)

    @property
    def sum_name(self):
        """ASDL name, prefixed with ``<enum_name>_`` for sum constructors."""
        if self.enum_name is None:
            return self.name
        return f"{self.enum_name}_{self.name}"

    @property
    def rust_sum_name(self):
        """Rust name, prefixed with the owning enum's Rust name if any."""
        rust_name = get_rust_type(self.name)
        if self.enum_name is None:
            return rust_name
        return get_rust_type(self.enum_name) + rust_name

    @property
    def rust_suffix(self):
        """Suffix of the payload type: "Data", "Kind" or the empty string.

        Types without attributes have no suffix; with attributes, products
        get "Data" and everything else gets "Kind".
        """
        if not self.has_attributes:
            return ""
        return "Data" if self.product else "Kind"

    def determine_userdata(self, typeinfo, stack):
        """Resolve ``has_userdata`` by inspecting this type's children.

        *typeinfo* maps type names to TypeInfo records and *stack* is the set
        of names currently being resolved.  Returns None immediately when
        this type is already on the stack (a recursive reference); otherwise
        returns the resulting ``has_userdata`` value, which may still be None.
        """
        if self.name in stack:
            return None
        stack.add(self.name)
        for child, child_seq in self.children:
            if child in asdl.builtin_types:
                continue
            childinfo = typeinfo[child]
            child_has_userdata = childinfo.determine_userdata(typeinfo, stack)
            # Only an undecided type inherits user data from a child; an
            # explicit True/False set earlier is never overwritten.
            if self.has_userdata is None and child_has_userdata is True:
                self.has_userdata = True
        stack.remove(self.name)
        return self.has_userdata
class TypeInfoMixin:
    """Helpers for classes that carry a ``typeinfo`` name-to-record mapping."""

    typeinfo: Dict[str, TypeInfo]

    def has_userdata(self, typ):
        """Return the ``has_userdata`` flag recorded for the type *typ*."""
        record = self.typeinfo[typ]
        return record.has_userdata

    def get_generics(self, typ, *generics):
        """Return one generic-parameter string per entry of *generics*.

        Each entry is wrapped in angle brackets when *typ* has user data;
        otherwise every entry becomes the empty string.
        """
        if not self.has_userdata(typ):
            return [""] * len(generics)
        return ["<{}>".format(generic) for generic in generics]
class EmitVisitor(asdl.VisitorBase, TypeInfoMixin):
    """Base visitor that writes generated lines to a file object."""

    def __init__(self, file, typeinfo):
        self.file = file
        self.typeinfo = typeinfo
        # Names already passed to emit_identifier, so each is emitted once.
        self.identifiers = set()
        super().__init__()

    def emit_identifier(self, name):
        """Emit a ``_Py_IDENTIFIER`` line for *name* the first time it is seen."""
        ident = str(name)
        if ident not in self.identifiers:
            self.emit("_Py_IDENTIFIER(%s);" % ident, 0)
            self.identifiers.add(ident)

    def emit(self, line, depth):
        """Write *line* plus a newline, indented by *depth* levels.

        An empty *line* is written as a bare newline without indentation.
        """
        indent = " " * TABSIZE * depth if line else ""
        self.file.write(indent + line + "\n")
class FindUserdataTypesVisitor(asdl.VisitorBase):
    """Populate a ``typeinfo`` dict with a TypeInfo record for every type."""

    def __init__(self, typeinfo):
        self.typeinfo = typeinfo
        super().__init__()

    def visitModule(self, mod):
        """Record every definition, then resolve ``has_userdata`` for all."""
        for definition in mod.dfns:
            self.visit(definition)
        in_progress = set()
        for record in self.typeinfo.values():
            record.determine_userdata(self.typeinfo, in_progress)

    def visitType(self, type):
        type_name = type.name
        self.typeinfo[type_name] = TypeInfo(type_name)
        self.visit(type.value, type_name)

    def visitSum(self, sum, name):
        """Fill in the record of sum *name* and of its field-bearing variants."""
        info = self.typeinfo[name]
        if is_simple(sum):
            info.has_userdata = False
        else:
            for cons in sum.types:
                if cons.fields:
                    # Each constructor with fields gets a record of its own,
                    # linked back to the enclosing sum through enum_name.
                    cons_info = TypeInfo(cons.name)
                    cons_info.enum_name = name
                    self.typeinfo[cons.name] = cons_info
                    self.add_children(cons.name, cons.fields)
            if len(sum.types) > 1:
                info.boxed = True
            if sum.attributes:
                # attributes means located, which has the `custom: U` field
                info.has_userdata = True
                info.has_attributes = True
        # The sum itself depends on the fields of all of its variants.
        for variant in sum.types:
            self.add_children(name, variant.fields)

    def visitProduct(self, product, name):
        """Fill in the record of product *name*."""
        info = self.typeinfo[name]
        if product.attributes:
            # attributes means located, which has the `custom: U` field
            info.has_userdata = True
            info.has_attributes = True
            info.has_expr = product_has_expr(product)
        if len(product.fields) > 2:
            info.boxed = True
        info.product = True
        self.add_children(name, product.fields)

    def add_children(self, name, fields):
        """Add a ``(type, seq)`` pair for each field to *name*'s children."""
        children = self.typeinfo[name].children
        for field in fields:
            children.add((field.type, field.seq))
def rust_field(field_name):
    """Return the Rust field name for *field_name*.

    ``type`` is renamed to ``type_``; every other name is returned as is.
    """
    return "type_" if field_name == "type" else field_name
def product_has_expr(product):
    """Return True if any field of *product* has a type other than "identifier"."""
    for field in product.fields:
        if field.type != "identifier":
            return True
    return False
class StructVisitor(EmitVisitor):
    """Visitor that emits the Rust struct and enum definitions of the AST."""

    def __init__(self, *args, **kw):
        super().__init__(*args, **kw)
        self.rust_type_defs = []

    def visitModule(self, mod):
        for definition in mod.dfns:
            self.visit(definition)

    def visitType(self, type, depth=0):
        self.visit(type.value, type.name, depth)

    def visitSum(self, sum, name, depth):
        """Dispatch to the simple or the constructor-carrying enum emitter."""
        handler = self.simple_sum if is_simple(sum) else self.sum_with_constructors
        handler(sum, name, depth)

    def emit_attrs(self, depth):
        """Emit the derive attribute placed above every generated type."""
        self.emit("#[derive(Clone, Debug, PartialEq)]", depth)

    def simple_sum(self, sum, name, depth):
        """Emit a field-less Rust enum listing the variants of *sum*."""
        enum_ident = get_rust_type(name)
        inner = depth + 1
        self.emit_attrs(depth)
        self.emit(f"pub enum {enum_ident} {{", depth)
        for variant in sum.types:
            self.emit(f"{variant.name},", inner)
        self.emit("}", depth)
        self.emit("", depth)

    def sum_with_constructors(self, sum, name, depth):
        """Emit payload structs, the enum and, with attributes, a type alias."""
        typeinfo = self.typeinfo[name]
        suffix = typeinfo.rust_suffix
        rustname = get_rust_type(name)

        # One payload struct per variant that carries fields.
        for cons in sum.types:
            if cons.fields:
                self.sum_subtype_struct(typeinfo, cons, rustname, depth)

        generics, generics_applied = self.get_generics(name, "U = ()", "U")
        self.emit_attrs(depth)
        self.emit(f"pub enum {rustname}{suffix}{generics} {{", depth)
        for cons in sum.types:
            if cons.fields:
                (cons_generics,) = self.get_generics(cons.name, "U")
                variant_line = f"{cons.name}({rustname}{cons.name}{cons_generics}),"
            else:
                variant_line = f"{cons.name},"
            self.emit(variant_line, depth + 1)
        self.emit("}", depth)

        # A sum with attributes is additionally exposed through an
        # Attributed<> alias under the plain type name.
        if typeinfo.has_attributes:
            self.emit(
                f"pub type {rustname}<U = ()> = Attributed<{rustname}{suffix}{generics_applied}, U>;",
                depth,
            )
        self.emit("", depth)

    def sum_subtype_struct(self, sum_typeinfo, t, rustname, depth):
        """Emit the payload struct of variant *t* and its ``From`` impl."""
        self.emit_attrs(depth)
        generics, generics_applied = self.get_generics(t.name, "U = ()", "U")
        payload_name = f"{rustname}{t.name}"
        self.emit(f"pub struct {payload_name}{generics} {{", depth)
        for field in t.fields:
            self.visit(field, sum_typeinfo, "pub ", depth + 1, t.name)
        self.emit("}", depth)
        conversion = textwrap.dedent(
            f"""
impl{generics_applied} From<{payload_name}{generics_applied}> for {rustname}{sum_typeinfo.rust_suffix}{generics_applied} {{
fn from(payload: {payload_name}{generics_applied}) -> Self {{
{rustname}{sum_typeinfo.rust_suffix}::{t.name}(payload)
}}
}}
"""
        )
        self.emit(conversion, depth)

    def visitConstructor(self, cons, parent, depth):
        """Emit *cons* as an enum variant with inline named fields."""
        if not cons.fields:
            self.emit(f"{cons.name},", depth)
            return
        self.emit(f"{cons.name} {{", depth)
        for field in cons.fields:
            self.visit(field, parent, "", depth + 1, cons.name)
        self.emit("},", depth)

    def visitField(self, field, parent, vis, depth, constructor=None):
        """Emit one ``name: Type,`` line, wrapping the type as required."""
        typ = get_rust_type(field.type)
        fieldtype = self.typeinfo.get(field.type)
        if fieldtype and fieldtype.has_userdata:
            typ = f"{typ}<U>"

        # don't box if we're doing Vec<T>, but do box if we're doing Vec<Option<Box<T>>>
        if fieldtype and fieldtype.boxed:
            if field.opt or not (parent.product or field.seq):
                typ = f"Box<{typ}>"

        # When a dictionary literal contains dictionary unpacking (e.g. `{**d}`),
        # the unpacked expression goes in `values` with a `None` at the matching
        # position in `keys`, so each element of `keys` is made optional; the
        # sequence wrapper applied below turns that into `Vec<Option<T>>`.
        is_dict_keys = constructor == "Dict" and field.name == "keys"
        if field.opt or is_dict_keys:
            typ = f"Option<{typ}>"
        if field.seq:
            typ = f"Vec<{typ}>"

        name = rust_field(field.name)
        self.emit(f"{vis}{name}: {typ},", depth)

    def visitProduct(self, product, name, depth):
        """Emit the struct of a product type and, with attributes, its alias."""
        typeinfo = self.typeinfo[name]
        generics, generics_applied = self.get_generics(name, "U = ()", "U")
        rustname = get_rust_type(name)
        dataname = rustname + "Data" if product.attributes else rustname
        self.emit_attrs(depth)

        # Only products with a non-identifier field carry generic parameters.
        has_expr = product_has_expr(product)
        datadef = f"{dataname}{generics}" if has_expr else dataname
        self.emit(f"pub struct {datadef} {{", depth)
        for field in product.fields:
            self.visit(field, typeinfo, "pub ", depth + 1)
        self.emit("}", depth)

        if product.attributes:
            # attributes should just be location info
            if not has_expr:
                generics_applied = ""
            self.emit(
                f"pub type {rustname}<U = ()> = Attributed<{dataname}{generics_applied}, U>;",
                depth,
            )
        self.emit("", depth)
class FoldTraitDefVisitor(EmitVisitor):
    """Emit the ``Fold`` trait with one ``fold_<type>`` method per ASDL type."""

    def visitModule(self, mod, depth):
        member_depth = depth + 1
        self.emit("pub trait Fold<U> {", depth)
        for member in (
            "type TargetU;",
            "type Error;",
            "fn map_user(&mut self, user: U) -> Result<Self::TargetU, Self::Error>;",
            "fn map_located<T>(&mut self, located: Attributed<T, U>) -> Result<Attributed<T, Self::TargetU>, Self::Error> { let custom = self.map_user(located.custom)?; Ok(Attributed { range: located.range, custom, node: located.node }) }",
        ):
            self.emit(member, member_depth)
        for definition in mod.dfns:
            self.visit(definition, depth + 2)
        self.emit("}", depth)

    def visitType(self, type, depth):
        """Emit a trait method that forwards to the free ``fold_<name>`` function."""
        name = type.name
        apply_u, apply_target_u = self.get_generics(name, "U", "Self::TargetU")
        enumname = get_rust_type(name)
        signature = f"fn fold_{name}(&mut self, node: {enumname}{apply_u}) -> Result<{enumname}{apply_target_u}, Self::Error> {{"
        self.emit(signature, depth)
        self.emit(f"fold_{name}(self, node)", depth + 1)
        self.emit("}", depth)
class FoldImplVisitor(EmitVisitor):
    """Emit ``Foldable`` impls and the free ``fold_<type>`` functions."""

    def visitModule(self, mod, depth):
        self.emit(
            "fn fold_located<U, F: Fold<U> + ?Sized, T, MT>(folder: &mut F, node: Attributed<T, U>, f: impl FnOnce(&mut F, T) -> Result<MT, F::Error>) -> Result<Attributed<MT, F::TargetU>, F::Error> {",
            depth,
        )
        self.emit(
            "let node = folder.map_located(node)?; Ok(Attributed { custom: node.custom, range: node.range, node: f(folder, node.node)? })",
            depth + 1,
        )
        self.emit("}", depth)
        for definition in mod.dfns:
            self.visit(definition, depth)

    def visitType(self, type, depth=0):
        self.visit(type.value, type.name, depth)

    def _emit_fold_preamble(self, name, rust_type, depth):
        """Emit the Foldable impl and open the ``fold_<name>`` function."""
        apply_t, apply_u, apply_target_u = self.get_generics(
            name, "T", "U", "F::TargetU"
        )
        self.emit(f"impl<T, U> Foldable<T, U> for {rust_type}{apply_t} {{", depth)
        self.emit(f"type Mapped = {rust_type}{apply_u};", depth + 1)
        self.emit(
            "fn fold<F: Fold<T, TargetU = U> + ?Sized>(self, folder: &mut F) -> Result<Self::Mapped, F::Error> {",
            depth + 1,
        )
        self.emit(f"folder.fold_{name}(self)", depth + 2)
        self.emit("}", depth + 1)
        self.emit("}", depth)
        self.emit(
            f"pub fn fold_{name}<U, F: Fold<U> + ?Sized>(#[allow(unused)] folder: &mut F, node: {rust_type}{apply_u}) -> Result<{rust_type}{apply_target_u}, F::Error> {{",
            depth,
        )

    def visitSum(self, sum, name, depth):
        """Emit the fold code for a sum: one match arm per constructor."""
        typeinfo = self.typeinfo[name]
        enumname = get_rust_type(name)
        self._emit_fold_preamble(name, enumname, depth)

        located = typeinfo.has_attributes
        if located:
            self.emit("fold_located(folder, node, |folder, node| {", depth)
        self.emit("match node {", depth + 1)
        for cons in sum.types:
            header, body, footer = self.make_pattern(
                enumname, typeinfo.rust_suffix, cons.name, cons.fields
            )
            self.emit(f"{header} {{ {body} }} {footer} => {{", depth + 2)
            self.gen_construction(header, cons.fields, footer, depth + 3)
            self.emit("}", depth + 2)
        self.emit("}", depth + 1)
        if located:
            self.emit("})", depth)
        self.emit("}", depth)

    def visitProduct(self, product, name, depth):
        """Emit the fold code for a product: destructure, fold, rebuild."""
        structname = get_rust_type(name)
        has_attributes = bool(product.attributes)
        self._emit_fold_preamble(name, structname, depth)

        if has_attributes:
            self.emit("fold_located(folder, node, |folder, node| {", depth)
            rustname = structname + "Data"
        else:
            rustname = structname
        # Only the field list of the pattern is used for products.
        _, body, _ = self.make_pattern(rustname, structname, None, product.fields)
        self.emit(f"let {rustname} {{ {body} }} = node;", depth + 1)
        self.gen_construction(rustname, product.fields, "", depth + 1)
        if has_attributes:
            self.emit("})", depth)
        self.emit("}", depth)

    def make_pattern(self, rustname, suffix, fieldname, fields):
        """Return ``(header, body, footer)`` pieces of a Rust pattern.

        *body* is the comma-joined list of Rust field names.  With fields the
        header also opens the payload struct and the footer closes it.
        """
        header = f"{rustname}{suffix}::{fieldname}"
        footer = ""
        if fields:
            header = f"{header}({rustname}{fieldname}"
            footer = ")"
        body = ",".join(rust_field(field.name) for field in fields)
        return header, body, footer

    def gen_construction(self, header, fields, footer, depth):
        """Emit an ``Ok(...)`` expression that rebuilds a node field by field."""
        self.emit(f"Ok({header} {{", depth)
        for field in fields:
            ident = rust_field(field.name)
            self.emit(f"{ident}: Foldable::fold({ident}, folder)?,", depth + 1)
        self.emit(f"}}{footer})", depth)
class FoldModuleVisitor(EmitVisitor):
    """Emit the feature-gated ``fold`` module wrapping the trait and impls."""

    def visitModule(self, mod):
        outer, inner = 0, 1
        for line, level in (
            ('#[cfg(feature = "fold")]', outer),
            ("pub mod fold {", outer),
            ("use super::*;", inner),
            ("use crate::fold_helpers::Foldable;", inner),
        ):
            self.emit(line, level)
        # Both sub-visitors write to the same file, one level deeper.
        for visitor_cls in (FoldTraitDefVisitor, FoldImplVisitor):
            visitor_cls(self.file, self.typeinfo).visit(mod, inner)
        self.emit("}", outer)
class ClassDefVisitor(EmitVisitor):
    """Emit one ``#[pyclass]`` struct per ASDL sum, constructor and product."""

    def visitModule(self, mod):
        for definition in mod.dfns:
            self.visit(definition)

    def visitType(self, type, depth=0):
        self.visit(type.value, type.name, depth)

    def visitSum(self, sum, name, depth):
        """Emit the ``NodeKind*`` base class, then a class per constructor."""
        structname = "NodeKind" + get_rust_type(name)
        for line in (
            f'#[pyclass(module = "_ast", name = {json.dumps(name)}, base = "AstNode")]',
            f"struct {structname};",
            "#[pyclass(flags(HAS_DICT, BASETYPE))]",
            f"impl {structname} {{}}",
        ):
            self.emit(line, depth)
        for cons in sum.types:
            self.visit(cons, sum.attributes, structname, depth)

    def visitConstructor(self, cons, attrs, base, depth):
        self.gen_classdef(cons.name, cons.fields, attrs, depth, base)

    def visitProduct(self, product, name, depth):
        self.gen_classdef(name, product.fields, product.attributes, depth)

    def gen_classdef(self, name, fields, attrs, depth, base="AstNode"):
        """Emit a ``Node*`` class that sets ``_fields`` and ``_attributes``."""

        def as_str_items(items):
            # Comma-joined Rust expressions, one interned string per item name.
            return ",".join(
                f"ctx.new_str(ascii!({json.dumps(item.name)})).into()"
                for item in items
            )

        structname = "Node" + get_rust_type(name)
        self.emit(
            f'#[pyclass(module = "_ast", name = {json.dumps(name)}, base = {json.dumps(base)})]',
            depth,
        )
        self.emit(f"struct {structname};", depth)
        self.emit("#[pyclass(flags(HAS_DICT, BASETYPE))]", depth)
        self.emit(f"impl {structname} {{", depth)
        self.emit("#[extend_class]", depth + 1)
        self.emit(
            "fn extend_class_with_fields(ctx: &Context, class: &'static Py<PyType>) {",
            depth + 1,
        )
        field_items = as_str_items(fields)
        self.emit(
            f"class.set_attr(identifier!(ctx, _fields), ctx.new_tuple(vec![{field_items}]).into());",
            depth + 2,
        )
        attr_items = as_str_items(attrs)
        self.emit(
            f"class.set_attr(identifier!(ctx, _attributes), ctx.new_list(vec![{attr_items}]).into());",
            depth + 2,
        )
        self.emit("}", depth + 1)
        self.emit("}", depth)
class ExtendModuleVisitor(EmitVisitor):
    """Emit ``extend_module_nodes``, which registers every generated class."""

    def visitModule(self, mod):
        depth = 0
        self.emit(
            "pub fn extend_module_nodes(vm: &VirtualMachine, module: &Py<PyModule>) {",
            depth,
        )
        self.emit("extend_module!(vm, module, {", depth + 1)
        entry_depth = depth + 2
        for definition in mod.dfns:
            self.visit(definition, entry_depth)
        self.emit("})", depth + 1)
        self.emit("}", depth)

    def visitType(self, type, depth):
        self.visit(type.value, type.name, depth)

    def visitSum(self, sum, name, depth):
        """Register the ``NodeKind*`` class, then each constructor class."""
        kind_class = "NodeKind" + get_rust_type(name)
        self.emit(f"{json.dumps(name)} => {kind_class}::make_class(&vm.ctx),", depth)
        for cons in sum.types:
            self.visit(cons, depth)

    def visitConstructor(self, cons, depth):
        self.gen_extension(cons.name, depth)

    def visitProduct(self, product, name, depth):
        self.gen_extension(name, depth)

    def gen_extension(self, name, depth):
        """Emit the registration entry for the ``Node*`` class of *name*."""
        node_class = "Node" + get_rust_type(name)
        self.emit(f"{json.dumps(name)} => {node_class}::make_class(&vm.ctx),", depth)
class TraitImplVisitor(EmitVisitor):
    """Emit ``NamedNode`` and ``Node`` impls for every located AST type."""

    def visitModule(self, mod):
        for definition in mod.dfns:
            self.visit(definition)

    def visitType(self, type, depth=0):
        self.visit(type.value, type.name, depth)

    def visitSum(self, sum, name, depth):
        """Emit both impls for a sum type."""
        enumname = get_rust_type(name)
        # A sum with attributes is implemented on its `*Kind` payload enum.
        rustname = enumname + "Kind" if sum.attributes else enumname

        self.emit(f"impl NamedNode for ast::located::{rustname} {{", depth)
        self.emit(f"const NAME: &'static str = {json.dumps(name)};", depth + 1)
        self.emit("}", depth)

        self.emit(f"impl Node for ast::located::{rustname} {{", depth)
        self.emit(
            "fn ast_to_object(self, _vm: &VirtualMachine) -> PyObjectRef {", depth + 1
        )
        self.emit("match self {", depth + 2)
        for variant in sum.types:
            self.constructor_to_object(variant, enumname, rustname, depth + 3)
        self.emit("}", depth + 2)
        self.emit("}", depth + 1)
        self.emit(
            "fn ast_from_object(_vm: &VirtualMachine, _object: PyObjectRef) -> PyResult<Self> {",
            depth + 1,
        )
        self.gen_sum_fromobj(sum, name, enumname, rustname, depth + 2)
        self.emit("}", depth + 1)
        self.emit("}", depth)

    def constructor_to_object(self, cons, enumname, rustname, depth):
        """Emit the ``match`` arm that converts variant *cons* to an object."""
        self.emit(f"ast::located::{rustname}::{cons.name}", depth)
        if cons.fields:
            bindings = self.make_pattern(cons.fields)
            self.emit(
                f"( ast::located::{enumname}{cons.name} {{ {bindings} }} )", depth
            )
        self.emit(" => {", depth)
        self.make_node(cons.name, cons.fields, depth + 1)
        self.emit("}", depth)

    def visitProduct(self, product, name, depth):
        """Emit both impls for a product type."""
        structname = get_rust_type(name)
        if product.attributes:
            structname = structname + "Data"

        self.emit(f"impl NamedNode for ast::located::{structname} {{", depth)
        self.emit(f"const NAME: &'static str = {json.dumps(name)};", depth + 1)
        self.emit("}", depth)

        self.emit(f"impl Node for ast::located::{structname} {{", depth)
        self.emit(
            "fn ast_to_object(self, _vm: &VirtualMachine) -> PyObjectRef {", depth + 1
        )
        bindings = self.make_pattern(product.fields)
        self.emit(
            f"let ast::located::{structname} {{ {bindings} }} = self;", depth + 2
        )
        self.make_node(name, product.fields, depth + 2)
        self.emit("}", depth + 1)
        self.emit(
            "fn ast_from_object(_vm: &VirtualMachine, _object: PyObjectRef) -> PyResult<Self> {",
            depth + 1,
        )
        self.gen_product_fromobj(product, name, structname, depth + 2)
        self.emit("}", depth + 1)
        self.emit("}", depth)

    def make_node(self, variant, fields, depth):
        """Emit code that builds the Python node and stores each field on it."""
        rust_variant = get_rust_type(variant)
        self.emit(
            f"let _node = AstNode.into_ref_with_type(_vm, Node{rust_variant}::static_type().to_owned()).unwrap();",
            depth,
        )
        # The instance dict is only fetched when there is a field to store.
        if fields:
            self.emit("let _dict = _node.as_object().dict().unwrap();", depth)
        for field in fields:
            key = json.dumps(field.name)
            self.emit(
                f"_dict.set_item({key}, {rust_field(field.name)}.ast_to_object(_vm), _vm).unwrap();",
                depth,
            )
        self.emit("_node.into()", depth)

    def make_pattern(self, fields):
        """Return the comma-joined Rust field names of *fields*."""
        return ",".join(rust_field(field.name) for field in fields)

    def gen_sum_fromobj(self, sum, sumname, enumname, rustname, depth):
        """Emit an if/else chain selecting the variant by Python class."""
        if sum.attributes:
            self.extract_location(sumname, depth)

        self.emit("let _cls = _object.class();", depth)
        self.emit("Ok(", depth)
        for cons in sum.types:
            self.emit(f"if _cls.is(Node{cons.name}::static_type()) {{", depth)
            if not cons.fields:
                self.emit(f"ast::located::{rustname}::{cons.name}", depth + 1)
            else:
                self.emit(
                    f"ast::located::{rustname}::{cons.name} (ast::located::{enumname}{cons.name} {{",
                    depth + 1,
                )
                self.gen_construction_fields(cons, sumname, depth + 1)
                self.emit("})", depth + 1)
            self.emit("} else", depth)

        # Final `else` branch: no known class matched.
        self.emit("{", depth)
        msg = f'format!("expected some sort of {sumname}, but got {{}}",_object.repr(_vm)?)'
        self.emit(f"return Err(_vm.new_type_error({msg}));", depth + 1)
        self.emit("})", depth)

    def gen_product_fromobj(self, product, prodname, structname, depth):
        """Emit the ``Ok(...)`` expression that rebuilds a product struct."""
        if product.attributes:
            self.extract_location(prodname, depth)

        self.emit("Ok(", depth)
        self.gen_construction(structname, product, prodname, depth + 1)
        self.emit(")", depth)

    def gen_construction_fields(self, cons, name, depth):
        """Emit one ``field: <decode expression>,`` line per field of *cons*."""
        for field in cons.fields:
            decoded = self.decode_field(field, name)
            self.emit(f"{rust_field(field.name)}: {decoded},", depth + 1)

    def gen_construction(self, cons_path, cons, name, depth):
        """Emit a struct literal for ``ast::located::<cons_path>``."""
        self.emit(f"ast::located::{cons_path} {{", depth)
        self.gen_construction_fields(cons, name, depth + 1)
        self.emit("}", depth)

    def extract_location(self, typename, depth):
        """Emit code binding ``_location`` from ``lineno`` and ``col_offset``."""
        row = self.decode_field(asdl.Field("int", "lineno"), typename)
        column = self.decode_field(asdl.Field("int", "col_offset"), typename)
        snippet = f"""let _location = {{
let row = try_location_field({row}, _vm)?;
let column = try_location_field({column}, _vm)?;
SourceLocation {{ row, column }}
}};"""
        self.emit(snippet, depth)

    def decode_field(self, field, typename):
        """Return the Rust expression that reads *field* from ``_object``.

        A field that is optional and not a sequence is read with
        ``get_node_field_opt``; every other field uses ``get_node_field``.
        """
        name = json.dumps(field.name)
        if field.seq or not field.opt:
            return f"Node::ast_from_object(_vm, get_node_field(_vm, &_object, {name}, {json.dumps(typename)})?)?"
        return f"get_node_field_opt(_vm, &_object, {name})?.map(|obj| Node::ast_from_object(_vm, obj)).transpose()?"
class ChainOfVisitors:
    """Run several visitors over the same object, one after another."""

    def __init__(self, *visitors):
        self.visitors = visitors

    def visit(self, object):
        """Visit *object* with each visitor, emitting a blank line after each."""
        for visitor in self.visitors:
            visitor.visit(object)
            visitor.emit("", 0)
def write_generic_def(mod, typeinfo, f):
    """Write the generic AST definitions and the fold module to *f*.

    A fixed prelude is written first; StructVisitor and FoldModuleVisitor are
    then run over *mod* in that order.
    """
    prelude = textwrap.dedent(
        """
pub use crate::{Attributed, constant::*};
type Ident = String;
\n
"""
    )
    f.write(prelude)

    struct_visitor = StructVisitor(f, typeinfo)
    fold_visitor = FoldModuleVisitor(f, typeinfo)
    ChainOfVisitors(struct_visitor, fold_visitor).visit(mod)
def write_located_def(typeinfo, f):
    """Write ``SourceRange``-specialised type aliases for every type to *f*.

    Each record gets an alias for its ``rust_sum_name``; a record with a
    non-empty ``rust_suffix`` gets a second alias for the suffixed name.
    """
    header = textwrap.dedent(
        """
use crate::location::SourceRange;
pub type Located<T> = super::generic::Attributed<T, SourceRange>;
"""
    )
    f.write(header)

    for info in typeinfo.values():
        generics = "::<SourceRange>" if info.has_userdata else ""
        f.write(
            f"pub type {info.rust_sum_name} = super::generic::{info.rust_sum_name}{generics};\n"
        )
        if not info.rust_suffix:
            continue
        # A "Data" payload without has_expr is aliased without generics.
        if info.rust_suffix == "Data" and not info.has_expr:
            generics = ""
        f.write(
            f"pub type {info.rust_sum_name}{info.rust_suffix} = super::generic::{info.rust_sum_name}{info.rust_suffix}{generics};\n"
        )
def write_ast_mod(mod, typeinfo, f):
    """Write the Python-facing AST module to *f*.

    A fixed prelude is written first; ClassDefVisitor, TraitImplVisitor and
    ExtendModuleVisitor are then run over *mod* in that order.
    """
    prelude = textwrap.dedent(
        """
#![allow(clippy::all)]
use super::*;
use crate::common::ascii;
"""
    )
    f.write(prelude)

    visitor_classes = (ClassDefVisitor, TraitImplVisitor, ExtendModuleVisitor)
    chain = ChainOfVisitors(*(cls(f, typeinfo) for cls in visitor_classes))
    chain.visit(mod)
def main(
    input_filename,
    generic_filename,
    located_filename,
    module_filename,
    dump_module=False,
):
    """Parse the ASDL file and regenerate the three output files.

    The three output arguments must provide an ``open("w")`` method.  When
    *dump_module* is true the parsed module is printed first.  The process
    exits with status 1 if ``asdl.check`` rejects the module.
    """
    script_path = "/".join(Path(__file__).parts[-2:])
    auto_gen_msg = AUTOGEN_MESSAGE.format(script_path)

    mod = asdl.parse(input_filename)
    if dump_module:
        print("Parsed Module:")
        print(mod)
    if not asdl.check(mod):
        sys.exit(1)

    # Collect per-type information once; every writer below reads it.
    typeinfo = {}
    FindUserdataTypesVisitor(typeinfo).visit(mod)

    with generic_filename.open("w") as generic_file:
        with located_filename.open("w") as located_file:
            generic_file.write(auto_gen_msg)
            write_generic_def(mod, typeinfo, generic_file)
            located_file.write(auto_gen_msg)
            write_located_def(typeinfo, located_file)

    with module_filename.open("w") as module_file:
        module_file.write(auto_gen_msg)
        write_ast_mod(mod, typeinfo, module_file)

    print(f"{generic_filename}, {located_filename}, {module_filename} regenerated.")
if __name__ == "__main__":
    # Command-line entry point: one positional ASDL input file, three
    # required output paths and an optional flag to dump the parsed module.
    parser = ArgumentParser()
    parser.add_argument("input_file", type=Path)
    for short_flag, long_flag in (
        ("-G", "--generic-file"),
        ("-L", "--located-file"),
        ("-M", "--module-file"),
    ):
        parser.add_argument(short_flag, long_flag, type=Path, required=True)
    parser.add_argument("-d", "--dump-module", action="store_true")

    args = parser.parse_args()
    main(
        input_filename=args.input_file,
        generic_filename=args.generic_file,
        located_filename=args.located_file,
        module_filename=args.module_file,
        dump_module=args.dump_module,
    )