RustPython-Parser/ast/asdl_rs.py
Jeong YunWon 1a07454dc7 ast::pyo3
2023-05-15 18:27:51 +09:00

1724 lines
54 KiB
Python
Executable file

# spell-checker:words dfn dfns
# ! /usr/bin/env python
"""Generate Rust code from an ASDL description."""
import sys
import json
import textwrap
import re
from argparse import ArgumentParser
from pathlib import Path
from typing import Optional, Dict
import asdl
# Number of spaces written per indentation level by `EmitVisitor.emit`.
TABSIZE = 4
# Header template for generated files; `{}` is filled in by the caller
# (presumably with the generator's name — confirm at the call site).
AUTO_GEN_MESSAGE = "// File automatically generated by {}.\n\n"
# Maps every ASDL builtin type name to the Rust type name used for it.
BUILTIN_TYPE_NAMES = {
"identifier": "Identifier",
"string": "String",
"int": "Int",
"constant": "Constant",
}
# Fail at import time if this table and asdl's builtin type set ever diverge.
assert BUILTIN_TYPE_NAMES.keys() == asdl.builtin_types
# Fields whose Rust type resolves to plain `Int` and whose name is listed here
# are emitted with the mapped Rust type instead (see `StructVisitor.visitField`).
BUILTIN_INT_NAMES = {
"simple": "bool",
"is_async": "bool",
}
# Names checked against snake_cased variant names in
# `StructVisitor.sum_with_constructors` to decide whether the generated
# `#[is(name = ...)]` helpers need a disambiguating suffix.
RUST_KEYWORDS = {"if", "while", "for", "return", "match", "try", "await", "yield"}
def rust_field_name(name):
    """Return the snake_case form of an ASDL name.

    The name is first normalised through ``rust_type_name``; an underscore is
    then inserted before every uppercase letter that is not at the start of
    the string, and the result is lowercased.
    """
    camel_cased = rust_type_name(name)
    with_separators = re.sub(r"(?<!^)(?=[A-Z])", "_", camel_cased)
    return with_separators.lower()
def rust_type_name(name):
    """Return the Rust type name for an ASDL type name.

    ASDL builtin types are looked up in ``BUILTIN_TYPE_NAMES``.  Any other
    name for which ``str.islower()`` is true is converted from snake_case to
    CamelCase; every remaining name is returned unchanged.
    """
    if name in asdl.builtin_types:
        return BUILTIN_TYPE_NAMES[name]
    if not name.islower():
        return name
    return "".join(word.capitalize() for word in name.split("_"))
def is_simple(sum):
    """Tell whether a sum type consists only of field-less constructors.

    Example of a simple sum::

        unaryop = Invert | Not | UAdd | USub

    Returns ``True`` when no entry of ``sum.types`` has fields (including the
    case of an empty ``sum.types``), ``False`` otherwise.
    """
    return not any(variant.fields for variant in sum.types)
def asdl_of(name, obj):
    """Render the ASDL text of the definition ``name`` described by ``obj``.

    Products and constructors are rendered as ``name(field, ...)`` (or just
    ``name`` when they have no fields).  Anything else is treated as a sum and
    rendered as ``name = A | B | ...``; non-simple sums put every alternative
    on its own line.
    """
    if isinstance(obj, (asdl.Product, asdl.Constructor)):
        rendered_fields = ", ".join(str(field) for field in obj.fields)
        suffix = "({})".format(rendered_fields) if rendered_fields else ""
        return "{}{}".format(name, suffix)

    if is_simple(obj):
        alternatives = " | ".join(variant.name for variant in obj.types)
    else:
        # Continuation lines are padded so that each `|` lines up under the `=`.
        separator = "\n{}| ".format(" " * (len(name) + 1))
        alternatives = separator.join(
            asdl_of(variant.name, variant) for variant in obj.types
        )
    return "{} = {}".format(name, alternatives)
class TypeInfo:
    """Per-type facts collected from the ASDL module and used by the emitters."""

    name: str
    # Name of the owning sum type when this entry describes a constructor.
    enum_name: Optional[str]
    # Tri-state: None until decided, then True/False.
    has_user_data: Optional[bool]
    has_attributes: bool
    is_simple: bool
    empty_field: bool
    # Set of (child type name, is-sequence) pairs.
    children: set
    boxed: bool
    product: bool
    has_expr: bool = False

    def __init__(self, name):
        self.name = name
        self.enum_name = None
        self.has_user_data = None
        self.has_attributes = False
        self.is_simple = False
        self.empty_field = False
        self.children = set()
        self.boxed = False
        self.product = False
        # `has_expr` is the declared attribute and the one assigned by
        # FindUserDataTypesVisitor.visitProduct, so initialise it per instance.
        self.has_expr = False
        # Legacy spelling, kept so code that still reads it keeps working.
        self.product_has_expr = False

    def __repr__(self):
        return f"<TypeInfo: {self.name}>"

    def no_cfg(self, typeinfo):
        """Return the relevant ``has_attributes`` flag for this type.

        For a constructor (non-product with ``enum_name`` set) the flag of the
        owning sum type, looked up in ``typeinfo``, is returned; otherwise this
        type's own flag is returned.
        """
        if not self.product and self.enum_name:
            return typeinfo[self.enum_name].has_attributes
        return self.has_attributes

    @property
    def rust_name(self):
        """Rust type name derived from ``name``."""
        return rust_type_name(self.name)

    @property
    def sum_name(self):
        """``<enum_name>_<name>`` for constructors, plain ``name`` otherwise."""
        if self.enum_name is None:
            return self.name
        return f"{self.enum_name}_{self.name}"

    @property
    def rust_sum_name(self):
        """Rust name, prefixed with the owning sum's Rust name for constructors."""
        rust_name = rust_type_name(self.name)
        if self.enum_name is None:
            return rust_name
        return rust_type_name(self.enum_name) + rust_name

    def determine_user_data(self, type_info, stack):
        """Propagate ``has_user_data`` upwards from children and return it.

        ``stack`` holds the names currently being visited; re-entering a name
        that is already on it (a cycle) returns ``None`` without a decision.
        Builtin ASDL child types are skipped.  An undecided (``None``) flag
        becomes ``True`` as soon as any child reports ``True``.
        """
        if self.name in stack:
            return None
        stack.add(self.name)
        for child, _is_seq in self.children:
            if child in asdl.builtin_types:
                continue
            child_has_user_data = type_info[child].determine_user_data(
                type_info, stack
            )
            if self.has_user_data is None and child_has_user_data is True:
                self.has_user_data = True

        stack.remove(self.name)
        return self.has_user_data
class TypeInfoMixin:
    """Helpers for classes that carry a ``type_info`` mapping."""

    type_info: Dict[str, TypeInfo]

    def has_user_data(self, typ):
        """Return the ``has_user_data`` flag recorded for ``typ``."""
        info = self.type_info[typ]
        return info.has_user_data

    def apply_generics(self, typ, *generics):
        """Return one generic-argument string per entry of ``generics``.

        Simple types take no generic parameter, so each entry is ``""``;
        otherwise each entry ``g`` becomes ``"<g>"``.
        """
        if self.type_info[typ].is_simple:
            return ["" for _ in generics]
        return [f"<{generic}>" for generic in generics]
class EmitVisitor(asdl.VisitorBase, TypeInfoMixin):
    """Base visitor that writes generated lines to a file-like object."""

    def __init__(self, file, type_info):
        self.file = file
        self.type_info = type_info
        # Identifier names already emitted by `emit_identifier`.
        self.identifiers = set()
        super().__init__()

    def emit_identifier(self, name):
        """Emit a `_Py_IDENTIFIER(...)` line for ``name`` once per visitor."""
        ident = str(name)
        if ident not in self.identifiers:
            self.emit("_Py_IDENTIFIER(%s);" % ident, 0)
            self.identifiers.add(ident)

    def emit(self, line, depth):
        """Write ``line`` followed by a newline, indented ``depth`` levels.

        An empty ``line`` is written as a bare newline.  Only the start of the
        string is indented; embedded newlines are written through unchanged.
        """
        if line:
            line = " " * (TABSIZE * depth) + line
        self.file.write(line + "\n")
class FindUserDataTypesVisitor(asdl.VisitorBase):
    """Populate ``type_info`` with a ``TypeInfo`` entry for every ASDL type."""

    def __init__(self, type_info):
        self.type_info = type_info
        super().__init__()

    def visitModule(self, mod):
        """Collect info for each definition, then resolve ``has_user_data``."""
        for definition in mod.dfns:
            self.visit(definition)
        in_progress = set()
        for entry in self.type_info.values():
            entry.determine_user_data(self.type_info, in_progress)

    def visitType(self, type):
        self.type_info[type.name] = TypeInfo(type.name)
        self.visit(type.value, type.name)

    def visitSum(self, sum, name):
        """Record facts about a sum type and register its constructors."""
        sum_info = self.type_info[name]
        if is_simple(sum):
            sum_info.has_user_data = False
            sum_info.is_simple = True
        else:
            # Every constructor gets its own entry pointing back at the sum.
            for variant in sum.types:
                variant_info = TypeInfo(variant.name)
                variant_info.enum_name = name
                variant_info.empty_field = not variant.fields
                self.type_info[variant.name] = variant_info
                self.add_children(variant.name, variant.fields)
            if len(sum.types) > 1:
                sum_info.boxed = True
            if sum.attributes:
                # attributes means located, which has the `range: R` field
                sum_info.has_user_data = True
                sum_info.has_attributes = True

        # The sum itself also lists the fields of all its constructors.
        for variant in sum.types:
            self.add_children(name, variant.fields)

    def visitProduct(self, product, name):
        """Record facts about a product type."""
        product_info = self.type_info[name]
        if product.attributes:
            # attributes means located, which has the `range: R` field
            product_info.has_user_data = True
            product_info.has_attributes = True
        product_info.has_expr = product_has_expr(product)
        if len(product.fields) > 2:
            product_info.boxed = True
        product_info.product = True
        self.add_children(name, product.fields)

    def add_children(self, name, fields):
        """Add ``(field type, is-sequence)`` pairs to the children of ``name``."""
        pairs = ((field.type, field.seq) for field in fields)
        self.type_info[name].children.update(pairs)
def rust_field(field_name):
    """Return the Rust field identifier for an ASDL field name.

    ``type`` is emitted as ``type_``; every other name is returned unchanged.
    """
    return "type_" if field_name == "type" else field_name
def product_has_expr(product):
    """Return True if any field of ``product`` has a type other than ``identifier``."""
    for field in product.fields:
        if field.type != "identifier":
            return True
    return False
class StructVisitor(EmitVisitor):
    """Emit the Rust struct/enum definitions and `Node` impls for the AST."""

    def __init__(self, *args, **kw):
        super().__init__(*args, **kw)

    def visitModule(self, mod):
        for definition in mod.dfns:
            self.visit(definition)

    def visitType(self, type, depth=0):
        self.visit(type.value, type.name, depth)

    def visitSum(self, sum, name, depth):
        """Emit the enum for a sum type followed by its `Node` impl."""
        if is_simple(sum):
            self.simple_sum(sum, name, depth)
        else:
            self.sum_with_constructors(sum, name, depth)

        (generics_applied,) = self.apply_generics(name, "R")
        enum_name = rust_type_name(name)
        self.emit(
            f"""
impl{generics_applied} Node for {enum_name}{generics_applied} {{
const NAME: &'static str = "{name}";
const FIELD_NAMES: &'static [&'static str] = &[];
}}
""",
            depth,
        )

    def emit_attrs(self, depth):
        """Emit the derive line shared by every generated type."""
        self.emit("#[derive(Clone, Debug, PartialEq)]", depth)

    def emit_range(self, has_attributes, depth):
        """Emit the `range` field one level below ``depth``.

        Types with attributes get `R`; the others get `OptionalRange<R>`.
        """
        range_type = "R" if has_attributes else "OptionalRange<R>"
        self.emit(f"pub range: {range_type},", depth + 1)

    def simple_sum(self, sum, name, depth):
        """Emit a field-less enum for a simple sum."""
        enum_name = rust_type_name(name)
        self.emit_attrs(depth)
        self.emit("#[derive(is_macro::Is, Copy, Hash, Eq)]", depth)
        self.emit(f"pub enum {enum_name} {{", depth)
        for variant in sum.types:
            self.emit(f"{variant.name},", depth + 1)
        self.emit("}", depth)
        self.emit("", depth)

    def sum_with_constructors(self, sum, name, depth):
        """Emit one payload struct per constructor, then the wrapping enum."""
        sum_info = self.type_info[name]
        enum_name = rust_type_name(name)
        # all the attributes right now are for location, so if it has attrs we
        # can just wrap it in Attributed<>
        for variant in sum.types:
            self.sum_subtype_struct(sum_info, variant, enum_name, depth)

        self.emit_attrs(depth)
        self.emit("#[derive(is_macro::Is)]", depth)
        self.emit(f"pub enum {enum_name}<R = TextRange> {{", depth)
        # If any variant's snake_case name is in RUST_KEYWORDS, every `is_*`
        # helper of this enum is renamed with the enum name as suffix.
        needs_escape = any(
            rust_field_name(variant.name) in RUST_KEYWORDS for variant in sum.types
        )
        for variant in sum.types:
            if needs_escape:
                self.emit(
                    f'#[is(name = "{rust_field_name(variant.name)}_{enum_name.lower()}")]',
                    depth + 1,
                )
            self.emit(f"{variant.name}({enum_name}{variant.name}<R>),", depth + 1)
        self.emit("}", depth)
        self.emit("", depth)

    def sum_subtype_struct(self, sum_type_info, t, rust_name, depth):
        """Emit the payload struct of constructor ``t`` plus its trait impls."""
        self.emit_attrs(depth)
        payload_name = f"{rust_name}{t.name}"
        self.emit(f"pub struct {payload_name}<R = TextRange> {{", depth)
        self.emit_range(sum_type_info.has_attributes, depth)
        for field in t.fields:
            self.visit(field, sum_type_info, "pub ", depth + 1, t.name)

        # The constructor's entry must agree with its sum about attributes.
        variant_info = self.type_info[t.name]
        assert sum_type_info.has_attributes == variant_info.no_cfg(self.type_info)

        self.emit("}", depth)
        quoted_names = ", ".join(f'"{field.name}"' for field in t.fields)
        self.emit(
            textwrap.dedent(
                f"""
impl<R> Node for {payload_name}<R> {{
const NAME: &'static str = "{t.name}";
const FIELD_NAMES: &'static [&'static str] = &[{quoted_names}];
}}
impl<R> From<{payload_name}<R>> for {rust_name}<R> {{
fn from(payload: {payload_name}<R>) -> Self {{
{rust_name}::{t.name}(payload)
}}
}}
"""
            ),
            depth,
        )

        self.emit("", depth)

    def visitConstructor(self, cons, parent, depth):
        """Emit a constructor as an inline enum variant."""
        if not cons.fields:
            self.emit(f"{cons.name},", depth)
            return
        self.emit(f"{cons.name} {{", depth)
        for field in cons.fields:
            self.visit(field, parent, "", depth + 1, cons.name)
        self.emit("},", depth)

    def visitField(self, field, parent, vis, depth, constructor=None):
        """Emit one struct field, wrapping its type in Box/Option/Vec as needed."""
        rust_type = rust_type_name(field.type)
        field_info = self.type_info.get(field.type)
        if field_info and not field_info.is_simple:
            rust_type = f"{rust_type}<R>"
        # don't box if we're doing Vec<T>, but do box if we're doing Vec<Option<Box<T>>>
        if (
            field_info
            and field_info.boxed
            and (not (parent.product or field.seq) or field.opt)
        ):
            rust_type = f"Box<{rust_type}>"
        # When a dictionary literal contains dictionary unpacking (e.g., `{**d}`),
        # the expression to be unpacked goes in `values` with a `None` at the
        # corresponding position in `keys`. To handle this, the elements of
        # `keys` are optional, i.e. its type becomes `Vec<Option<T>>`.
        is_dict_keys = constructor == "Dict" and field.name == "keys"
        if field.opt or is_dict_keys:
            rust_type = f"Option<{rust_type}>"
        if field.seq:
            rust_type = f"Vec<{rust_type}>"
        if rust_type == "Int":
            # Some plain-int fields are known by name to be something else.
            rust_type = BUILTIN_INT_NAMES.get(field.name, rust_type)
        field_ident = rust_field(field.name)
        self.emit(f"{vis}{field_ident}: {rust_type},", depth)

    def visitProduct(self, product, name, depth):
        """Emit the struct for a product type followed by its `Node` impl."""
        product_info = self.type_info[name]
        product_name = rust_type_name(name)
        self.emit_attrs(depth)

        self.emit(f"pub struct {product_name}<R = TextRange> {{", depth)
        for field in product.fields:
            self.visit(field, product_info, "pub ", depth + 1)
        assert bool(product.attributes) == product_info.no_cfg(self.type_info)
        # NOTE(review): emit_range already adds one level itself, so the range
        # line ends up one level deeper than the fields; whitespace only.
        self.emit_range(product.attributes, depth + 1)
        self.emit("}", depth)

        quoted_names = ", ".join(f'"{field.name}"' for field in product.fields)
        self.emit(
            f"""
impl<R> Node for {product_name}<R> {{
const NAME: &'static str = "{name}";
const FIELD_NAMES: &'static [&'static str] = &[
{quoted_names}
];
}}
""",
            depth,
        )
class FoldTraitDefVisitor(EmitVisitor):
    """Emit the `Fold` trait with one `fold_<type>` method per ASDL type."""

    def visitModule(self, mod, depth):
        inner = depth + 1
        self.emit("pub trait Fold<U> {", depth)
        self.emit("type TargetU;", inner)
        self.emit("type Error;", inner)
        self.emit(
            """
fn map_user(&mut self, user: U) -> Result<Self::TargetU, Self::Error>;
#[cfg(feature = "all-nodes-with-ranges")]
fn map_user_cfg(&mut self, user: U) -> Result<Self::TargetU, Self::Error> {
self.map_user(user)
}
#[cfg(not(feature = "all-nodes-with-ranges"))]
fn map_user_cfg(&mut self, _user: crate::EmptyRange<U>) -> Result<crate::EmptyRange<Self::TargetU>, Self::Error> {
Ok(crate::EmptyRange::default())
}
""",
            inner,
        )
        self.emit(
            """
fn fold<X: Foldable<U, Self::TargetU>>(&mut self, node: X) -> Result<X::Mapped, Self::Error> {
node.fold(self)
}""",
            inner,
        )
        for definition in mod.dfns:
            self.visit(definition, depth + 2)
        self.emit("}", depth)

    def visitType(self, type, depth):
        """Emit the default trait method that delegates to the free `fold_<name>` function."""
        type_name = type.name
        source_generic, target_generic = self.apply_generics(
            type_name, "U", "Self::TargetU"
        )
        rust_name = rust_type_name(type_name)
        self.emit(
            f"fn fold_{type_name}(&mut self, node: {rust_name}{source_generic}) -> Result<{rust_name}{target_generic}, Self::Error> {{",
            depth,
        )
        self.emit(f"fold_{type_name}(self, node)", depth + 1)
        self.emit("}", depth)
class FoldImplVisitor(EmitVisitor):
    """Emit the `Foldable` impls and the free `fold_<type>` functions."""

    def visitModule(self, mod, depth):
        for definition in mod.dfns:
            self.visit(definition, depth)

    def visitType(self, type, depth=0):
        self.visit(type.value, type.name, depth)

    def _emit_foldable_header(self, name, rust_name, generics, depth):
        """Emit the `Foldable` impl and the opening line of `fold_<name>`."""
        apply_t, apply_u, apply_target_u = generics
        self.emit(f"impl<T, U> Foldable<T, U> for {rust_name}{apply_t} {{", depth)
        self.emit(f"type Mapped = {rust_name}{apply_u};", depth + 1)
        self.emit(
            "fn fold<F: Fold<T, TargetU = U> + ?Sized>(self, folder: &mut F) -> Result<Self::Mapped, F::Error> {",
            depth + 1,
        )
        self.emit(f"folder.fold_{name}(self)", depth + 2)
        self.emit("}", depth + 1)
        self.emit("}", depth)

        self.emit(
            f"pub fn fold_{name}<U, F: Fold<U> + ?Sized>(#[allow(unused)] folder: &mut F, node: {rust_name}{apply_u}) -> Result<{rust_name}{apply_target_u}, F::Error> {{",
            depth,
        )

    def visitSum(self, sum, name, depth):
        """Emit the fold implementation for a sum type."""
        sum_info = self.type_info[name]
        generics = self.apply_generics(name, "T", "U", "F::TargetU")
        enum_name = rust_type_name(name)
        simple = is_simple(sum)

        self._emit_foldable_header(name, enum_name, generics, depth)
        if simple:
            # Nothing to map inside a field-less enum.
            self.emit("Ok(node) }", depth + 1)
            return

        # Types without attributes only carry a range under the cfg feature.
        map_user_suffix = "" if sum_info.has_attributes else "_cfg"
        self.emit("match node {", depth + 1)
        for cons in sum.types:
            header, body, footer = self.make_pattern(
                enum_name, cons.name, cons.fields
            )
            self.emit(f"{header} {{ {body}}} {footer} => {{", depth + 2)
            self.emit(
                f"let range = folder.map_user{map_user_suffix}(range)?;", depth + 3
            )
            self.gen_construction(header, cons.fields, footer, depth + 3)
            self.emit("}", depth + 2)
        self.emit("}", depth + 1)
        self.emit("}", depth)

    def visitProduct(self, product, name, depth):
        """Emit the fold implementation for a product type."""
        generics = self.apply_generics(name, "T", "U", "F::TargetU")
        struct_name = rust_type_name(name)
        has_attributes = bool(product.attributes)

        self._emit_foldable_header(name, struct_name, generics, depth)

        _, body, _ = self.make_pattern(struct_name, struct_name, product.fields)
        self.emit(f"let {struct_name} {{ {body} }} = node;", depth + 1)

        map_user_suffix = "" if has_attributes else "_cfg"
        # NOTE(review): depth + 3 is deeper than the neighbouring lines
        # (depth + 1); it only affects whitespace of the generated code.
        self.emit(f"let range = folder.map_user{map_user_suffix}(range)?;", depth + 3)

        self.gen_construction(struct_name, product.fields, "", depth + 1)

        self.emit("}", depth)

    def make_pattern(self, rust_name, fieldname: str, fields):
        """Return ``(header, body, footer)`` pieces of a destructuring pattern.

        ``body`` lists the Rust field names separated by commas and always
        ends with ``range``.
        """
        header = f"{rust_name}::{fieldname}({rust_name}{fieldname}"
        footer = ")"

        bindings = ",".join(rust_field(field.name) for field in fields)
        if bindings:
            bindings += ","
        bindings += "range"

        return header, bindings, footer

    def gen_construction(self, header, fields, footer, depth):
        """Emit the `Ok(...)` expression rebuilding the node from folded fields."""
        self.emit(f"Ok({header} {{", depth)
        for field in fields:
            ident = rust_field(field.name)
            self.emit(f"{ident}: Foldable::fold({ident}, folder)?,", depth + 1)
        self.emit("range,", depth + 1)

        self.emit(f"}}{footer})", depth)
class FoldModuleVisitor(EmitVisitor):
    """Emit the whole fold module: the import, the trait, then the impls."""

    def visitModule(self, mod):
        top_level = 0
        self.emit("use crate::fold_helpers::Foldable;", top_level)
        for visitor_class in (FoldTraitDefVisitor, FoldImplVisitor):
            visitor_class(self.file, self.type_info).visit(mod, top_level)
class VisitorTraitDefVisitor(StructVisitor):
    """Emit the `Visitor` trait with `visit_*` / `generic_visit_*` methods."""

    def full_name(self, name):
        """Return ``<enum>_<name>`` for constructors, ``name`` otherwise."""
        info = self.type_info[name]
        if not info.enum_name:
            return name
        return f"{info.enum_name}_{name}"

    def node_type_name(self, name):
        """Return the Rust type name, prefixed by the owning enum's for constructors."""
        info = self.type_info[name]
        if not info.enum_name:
            return rust_type_name(name)
        return f"{rust_type_name(info.enum_name)}{rust_type_name(name)}"

    def visitModule(self, mod, depth):
        self.emit("pub trait Visitor<R=crate::text_size::TextRange> {", depth)

        for definition in mod.dfns:
            self.visit(definition, depth + 1)
        self.emit("}", depth)

    def visitType(self, type, depth=0):
        self.visit(type.value, type.name, depth)

    def visitSum(self, sum, name, depth):
        if is_simple(sum):
            self.simple_sum(sum, name, depth)
        else:
            self.sum_with_constructors(sum, name, depth)

    def emit_visitor(self, nodename, depth, has_node=True):
        """Emit `visit_<node>`; it delegates to `generic_visit_<node>` if ``has_node``."""
        info = self.type_info[nodename]
        node_type = info.rust_sum_name
        (generic,) = self.apply_generics(nodename, "R")
        self.emit(
            f"fn visit_{info.sum_name}(&mut self, node: {node_type}{generic}) {{",
            depth,
        )
        if has_node:
            self.emit(f"self.generic_visit_{info.sum_name}(node)", depth + 1)

        self.emit("}", depth)

    def emit_generic_visitor_signature(self, nodename, depth, has_node=True):
        """Emit only the opening line of `generic_visit_<node>`."""
        info = self.type_info[nodename]
        node_type = info.rust_sum_name if has_node else "()"
        (generic,) = self.apply_generics(nodename, "R")
        self.emit(
            f"fn generic_visit_{info.sum_name}(&mut self, node: {node_type}{generic}) {{",
            depth,
        )

    def emit_empty_generic_visitor(self, nodename, depth):
        """Emit a `generic_visit_<node>` with an empty body."""
        self.emit_generic_visitor_signature(nodename, depth)
        self.emit("}", depth)

    def simple_sum(self, sum, name, depth):
        self.emit_visitor(name, depth)
        self.emit_empty_generic_visitor(name, depth)

    def visit_match_for_type(self, nodename, rust_name, type_, depth):
        """Emit the match arm dispatching one variant to its `visit_*` method."""
        self.emit(f"{rust_name}::{type_.name}", depth)
        self.emit("(data)", depth)
        self.emit(f"=> self.visit_{nodename}_{type_.name}(data),", depth)

    def visit_sum_type(self, name, type_, depth):
        """Emit the visitor pair for one constructor of a sum type.

        The generic visitor walks only those fields whose type has user data.
        """
        self.emit_visitor(type_.name, depth, has_node=type_.fields)
        if not type_.fields:
            return

        self.emit_generic_visitor_signature(type_.name, depth, has_node=True)
        for field in type_.fields:
            field_ident = rust_field(field.name)
            field_info = self.type_info.get(field.type)
            if not (field_info and field_info.has_user_data):
                continue

            # Open a scope that binds `value` to the (or each) child node.
            if field.opt:
                self.emit(f"if let Some(value) = node.{field_ident} {{", depth + 1)
            elif field.seq:
                iterable = f"node.{field_ident}"
                if type_.name == "Dict" and field.name == "keys":
                    # Dict keys are optional; skip the `None` entries.
                    iterable = f"{iterable}.into_iter().flatten()"
                self.emit(f"for value in {iterable} {{", depth + 1)
            else:
                self.emit("{", depth + 1)
                self.emit(f"let value = node.{field_ident};", depth + 2)

            # Boxed children are dereferenced before being passed on.
            needs_deref = field_info.boxed and (not field.seq or field.opt)
            argument = "*value" if needs_deref else "value"
            child_info = self.type_info[field_info.name]
            self.emit(f"self.visit_{child_info.sum_name}({argument});", depth + 2)

            self.emit("}", depth + 1)

        self.emit("}", depth)

    def sum_with_constructors(self, sum, name, depth):
        """Emit visitors for a non-simple sum; sums without attributes are skipped."""
        if not sum.attributes:
            return

        enum_name = rust_type_name(name)
        self.emit_visitor(name, depth)
        self.emit_generic_visitor_signature(name, depth)
        body_depth = depth + 1
        self.emit("match node {", body_depth)
        for variant in sum.types:
            self.visit_match_for_type(name, enum_name, variant, body_depth + 1)
        self.emit("}", body_depth)
        self.emit("}", depth)

        # Now for the visitors for the types
        for variant in sum.types:
            self.visit_sum_type(name, variant, depth)

    def visitProduct(self, product, name, depth):
        self.emit_visitor(name, depth)
        self.emit_empty_generic_visitor(name, depth)
class VisitorModuleVisitor(EmitVisitor):
    """Emit the visitor module: a lint allowance followed by the trait."""

    def visitModule(self, mod):
        top_level = 0
        self.emit("#[allow(unused_variables, non_snake_case)]", top_level)
        trait_visitor = VisitorTraitDefVisitor(self.file, self.type_info)
        trait_visitor.visit(mod, top_level)
class ToPyo3AstVisitor(EmitVisitor):
    """Emit `ToPyo3Ast` impls for every type under `crate::<namespace>`."""

    def __init__(self, namespace, *args, **kw):
        super().__init__(*args, **kw)
        self.namespace = namespace

    def visitModule(self, mod):
        for dfn in mod.dfns:
            self.visit(dfn)

    def visitType(self, type, depth=0):
        self.visit(type.value, type.name, depth)

    def visitProduct(self, product, name, depth=0):
        """Emit the impl for a product: call the Python class with all fields."""
        rust_name = rust_type_name(name)
        self.emit(
            f"""
// product
impl ToPyo3Ast for crate::{self.namespace}::{rust_name} {{
fn to_pyo3_ast(&self, _py: Python) -> PyResult<Py<PyAny>> {{
let class = ranged::{rust_name}::py_type_cell().get().unwrap();
let instance = class.call1(_py, (
""",
            0,
        )
        for field in product.fields:
            # Go through rust_field() like visitConstructor does, so a field
            # named `type` refers to the Rust member `type_`.
            self.emit(
                f"self.{rust_field(field.name)}.to_pyo3_ast(_py)?,",
                depth + 1,
            )
        self.emit(
            """
))?;
Ok(instance.into())
}
}
""",
            0,
        )

    def visitSum(self, sum, name, depth=0):
        """Emit the dispatching impl for a sum, then one impl per constructor.

        Simple sums map each variant straight to its cached Python object and
        have no per-constructor impls.
        """
        rust_name = rust_type_name(name)
        simple = is_simple(sum)  # loop-invariant; computed once
        self.emit(
            f"""
impl ToPyo3Ast for crate::{self.namespace}::{rust_name} {{
fn to_pyo3_ast(&self, _py: Python) -> PyResult<Py<PyAny>> {{
let instance = match &self {{
""",
            0,
        )
        for cons in sum.types:
            if not simple:
                self.emit(
                    f"""crate::{rust_name}::{cons.name}(cons) => cons.to_pyo3_ast(_py)?,""",
                    depth,
                )
            else:
                self.emit(
                    f"""crate::{rust_name}::{cons.name} => ranged::{rust_name}{cons.name}::py_type_cell().get().unwrap().clone(),""",
                    depth,
                )
        self.emit(
            """
};
Ok(instance)
}
}
""",
            0,
        )

        if simple:
            return

        for cons in sum.types:
            self.visit(cons, rust_name, depth)

    def visitConstructor(self, cons, parent, depth):
        """Emit the impl for one constructor payload struct of ``parent``."""
        self.emit(
            f"""
// constructor
impl ToPyo3Ast for crate::{self.namespace}::{parent}{cons.name} {{
fn to_pyo3_ast(&self, _py: Python) -> PyResult<Py<PyAny>> {{
let class = ranged::{parent}{cons.name}::py_type_cell().get().unwrap();
let instance = class.call1(_py, (
""",
            depth,
        )
        for field in cons.fields:
            self.emit(
                f"self.{rust_field(field.name)}.to_pyo3_ast(_py)?,",
                depth + 1,
            )
        self.emit(
            """
))?;
Ok(instance.into())
}
}
""",
            depth,
        )
class RangedDefVisitor(EmitVisitor):
    """Emit `TextRange` type aliases and `Ranged` impls for every AST type."""

    def visitModule(self, mod):
        for dfn in mod.dfns:
            self.visit(dfn)

    def visitType(self, type, depth=0):
        self.visit(type.value, type.name, depth)

    def visitSum(self, sum, name, depth):
        """Emit alias and impl per constructor, then the enum-level impl.

        Simple sums only get the type alias.
        """
        info = self.type_info[name]

        self.emit_type_alias(info)

        if info.is_simple:
            return

        match_arms = []
        for ty in sum.types:
            variant_info = self.type_info[ty.name]
            match_arms.append(
                f" Self::{variant_info.rust_name}(node) => node.range(),"
            )
            self.emit_type_alias(variant_info)
            self.emit_ranged_impl(variant_info)
        sum_match_arms = "".join(match_arms)

        # Without attributes the range only exists under the cfg feature.
        if not info.no_cfg(self.type_info):
            self.emit('#[cfg(feature = "all-nodes-with-ranges")]', 0)

        self.emit(
            f"""
impl Ranged for crate::{info.rust_sum_name} {{
fn range(&self) -> TextRange {{
match self {{
{sum_match_arms}
}}
}}
}}
""".lstrip(),
            0,
        )

    def visitProduct(self, product, name, depth):
        info = self.type_info[name]

        self.emit_type_alias(info)
        self.emit_ranged_impl(info)

    def emit_type_alias(self, info):
        """Emit `pub type X = crate::generic::X[::<TextRange>];` and a blank line."""
        generics = "" if info.is_simple else "::<TextRange>"

        self.emit(
            f"pub type {info.rust_sum_name} = crate::generic::{info.rust_sum_name}{generics};",
            0,
        )
        self.emit("", 0)

    def emit_ranged_impl(self, info):
        """Emit the `Ranged` impl that returns the node's `range` field."""
        if not info.no_cfg(self.type_info):
            self.emit('#[cfg(feature = "all-nodes-with-ranges")]', 0)

        # Go through emit() so the impl is newline-terminated; a bare
        # file.write() left the next item fused onto the closing `}` line.
        self.emit(
            f"""
impl Ranged for crate::generic::{info.rust_sum_name}::<TextRange> {{
fn range(&self) -> TextRange {{
self.range
}}
}}
""".strip(),
            0,
        )
class LocatedDefVisitor(EmitVisitor):
    """Emit `SourceRange` type aliases and `Located` impls for every AST type."""

    def visitModule(self, mod):
        for definition in mod.dfns:
            self.visit(definition)

    def visitType(self, type, depth=0):
        self.visit(type.value, type.name, depth)

    def visitSum(self, sum, name, depth):
        """Emit alias and impl per constructor, then the enum-level impl.

        Simple sums only get the type alias.
        """
        sum_info = self.type_info[name]
        self.emit_type_alias(sum_info)
        if sum_info.is_simple:
            return

        arms = []
        for variant in sum.types:
            variant_info = self.type_info[variant.name]
            arms.append(f" Self::{variant_info.rust_name}(node) => node.range(),")
            self.emit_type_alias(variant_info)
            self.emit_located_impl(variant_info)
        sum_match_arms = "".join(arms)

        # Without attributes the range only exists under the cfg feature.
        if not sum_info.no_cfg(self.type_info):
            self.emit('#[cfg(feature = "all-nodes-with-ranges")]', 0)

        impl_text = f"""
impl Located for {sum_info.rust_sum_name} {{
fn range(&self) -> SourceRange {{
match self {{
{sum_match_arms}
}}
}}
}}
""".lstrip()
        self.emit(impl_text, 0)

    def visitProduct(self, product, name, depth):
        product_info = self.type_info[name]
        self.emit_type_alias(product_info)
        self.emit_located_impl(product_info)

    def emit_type_alias(self, info):
        """Emit `pub type X = crate::generic::X[::<SourceRange>];` and a blank line."""
        if info.is_simple:
            generics = ""
        else:
            generics = "::<SourceRange>"
        alias = info.rust_sum_name
        self.emit(
            f"pub type {alias} = crate::generic::{alias}{generics};",
            0,
        )
        self.emit("", 0)

    def emit_located_impl(self, info):
        """Emit the `Located` impl that returns the node's `range` field."""
        if not info.no_cfg(self.type_info):
            self.emit('#[cfg(feature = "all-nodes-with-ranges")]', 0)

        impl_text = f"""
impl Located for {info.rust_sum_name} {{
fn range(&self) -> SourceRange {{
self.range
}}
}}
"""
        self.emit(impl_text, 0)
class Pyo3StructVisitor(EmitVisitor):
    """Emit the `#[pyclass]` wrapper types for the AST under a namespace."""

    def __init__(self, namespace, *args, borrow=False, **kw):
        self.namespace = namespace
        self.borrow = borrow
        super().__init__(*args, **kw)

    @property
    def module_name(self):
        """Python module name the generated classes are registered under."""
        return f"rustpython_ast.{self.namespace}"

    @property
    def ref_def(self):
        """Reference prefix used in type positions when borrowing."""
        if self.borrow:
            return "&'static "
        return ""

    @property
    def ref(self):
        """Reference prefix used in expression positions when borrowing."""
        if self.borrow:
            return "&"
        return ""

    def emit_class(self, name, rust_name, subclass, base="super::AST"):
        """Emit the pyclass struct, its type cell, `From` and `ToPyObject` impls.

        With ``subclass`` truthy the struct is a field-less base class that
        also gets a `#[new]` constructor; otherwise it wraps the AST node.
        """
        is_base_class = bool(subclass)
        if is_base_class:
            subclass_attr = ", subclass"
            body = ""
            into = f"{rust_name}"
        else:
            subclass_attr = ""
            body = f"(pub {self.ref_def} crate::{self.namespace}::{rust_name})"
            into = f"{rust_name}(node)"
        # The `node` argument is unused when the struct has no payload.
        arg_prefix = "" if body else "_"

        self.emit(
            textwrap.dedent(
                f"""
#[pyclass(module="{self.module_name}", name="_{name}", extends={base}{subclass_attr})]
#[derive(Clone, Debug)]
pub struct {rust_name} {body};
impl {rust_name} {{
#[inline]
pub fn py_type_cell() -> &'static OnceCell<Py<PyAny>> {{
static PY_TYPE: OnceCell<Py<PyAny>> = OnceCell::new();
&PY_TYPE
}}
}}
impl From<{self.ref_def} crate::{self.namespace}::{rust_name}> for {rust_name} {{
fn from({arg_prefix}node: {self.ref_def} crate::{self.namespace}::{rust_name}) -> Self {{
{into}
}}
}}
"""
            ),
            0,
        )
        if is_base_class:
            self.emit(
                textwrap.dedent(
                    f"""
#[pymethods]
impl {rust_name} {{
#[new]
fn new() -> PyClassInitializer<Self> {{
PyClassInitializer::from(AST)
.add_subclass(Self)
}}
}}
impl ToPyObject for {rust_name} {{
fn to_object(&self, py: Python) -> PyObject {{
let initializer = PyClassInitializer::from(AST)
.add_subclass(self.clone());
Py::new(py, initializer).unwrap().into_py(py)
}}
}}
"""
                ),
                0,
            )
        else:
            # A non-default base sits between AST and this class.
            add_subclass = f".add_subclass({base})" if base != "super::AST" else ""
            self.emit(
                textwrap.dedent(
                    f"""
impl ToPyObject for {rust_name} {{
fn to_object(&self, py: Python) -> PyObject {{
let initializer = PyClassInitializer::from(AST)
{add_subclass}
.add_subclass(self.clone());
Py::new(py, initializer).unwrap().into_py(py)
}}
}}
"""
                ),
                0,
            )

        if self.borrow and not is_base_class:
            self.emit_wrapper(rust_name)

    def emit_getter(self, owner, type_name):
        """Emit a `#[pymethods]` block with one `#[getter]` per field of ``owner``."""
        self.emit(
            textwrap.dedent(
                f"""
#[pymethods]
impl {type_name} {{
"""
            ),
            0,
        )
        for field in owner.fields:
            getter = textwrap.dedent(
                f"""
#[getter]
#[inline]
fn get_{field.name}(&self, py: Python) -> PyResult<PyObject> {{
self.0.{rust_field(field.name)}.to_pyo3_wrapper(py)
}}
"""
            )
            self.emit(getter, 3)
        self.emit(
            textwrap.dedent(
                """
}
"""
            ),
            0,
        )

    def emit_getattr(self, owner, type_name):
        """Emit a `__getattr__` that matches on the field names of ``owner``."""
        self.emit(
            textwrap.dedent(
                f"""
#[pymethods]
impl {type_name} {{
fn __getattr__(&self, py: Python, key: &str) -> PyResult<PyObject> {{
let object: Py<PyAny> = match key {{
"""
            ),
            0,
        )
        for field in owner.fields:
            self.emit(
                f'"{field.name}" => self.0.{rust_field(field.name)}.to_pyo3_wrapper(py)?,',
                3,
            )
        self.emit(
            textwrap.dedent(
                """
_ => todo!(),
};
Ok(object)
}
}
"""
            ),
            0,
        )

    def emit_wrapper(self, rust_name):
        """Emit the `ToPyo3Wrapper` impl that wraps a node in its pyclass."""
        wrapper_impl = f"""
impl ToPyo3Wrapper for crate::{self.namespace}::{rust_name} {{
#[inline]
fn to_pyo3_wrapper(&'static self, py: Python) -> PyResult<Py<PyAny>> {{
Ok({rust_name}(self).to_object(py))
}}
}}
"""
        self.emit(wrapper_impl, 0)

    def visitModule(self, mod):
        for definition in mod.dfns:
            self.visit(definition)

    def visitType(self, type, depth=0):
        self.visit(type.value, type.name, depth)

    def visitSum(self, sum, name, depth=0):
        """Emit the base class of a sum, its wrapper dispatch, then each constructor."""
        rust_name = rust_type_name(name)
        self.emit_class(name, rust_name, True)

        simple = is_simple(sum)
        if self.borrow:
            self.emit(
                f"""
impl ToPyo3Wrapper for crate::{self.namespace}::{rust_name} {{
#[inline]
fn to_pyo3_wrapper(&'static self, py: Python) -> PyResult<Py<PyAny>> {{
match &self {{
""",
                0,
            )

            for cons in sum.types:
                if simple:
                    arm = f"Self::{cons.name} => Ok({rust_name}{cons.name}.to_object(py)),"
                else:
                    arm = f"Self::{cons.name}(cons) => cons.to_pyo3_wrapper(py),"
                self.emit(arm, 3)

            self.emit(
                f"""
}}
}}
}}
""",
                0,
            )

        for cons in sum.types:
            self.visit(cons, rust_name, simple, depth + 1)

    def visitProduct(self, product, name, depth=0):
        rust_name = rust_type_name(name)
        self.emit_class(name, rust_name, False)
        if self.borrow:
            self.emit_getter(product, rust_name)

    def visitConstructor(self, cons, parent, simple, depth):
        """Emit the class for one constructor of the sum class ``parent``."""
        if not simple:
            class_name = f"{parent}{cons.name}"
            self.emit_class(cons.name, class_name, subclass=False, base=parent)
            if self.borrow:
                self.emit_getter(cons, class_name)
            return

        # Field-less constructors become unit structs extending the sum class.
        self.emit(
            f"""
#[pyclass(module="{self.module_name}", name="_{cons.name}", extends={parent})]
pub struct {parent}{cons.name};
impl {parent}{cons.name} {{
#[inline]
pub fn py_type_cell() -> &'static OnceCell<Py<PyAny>> {{
static PY_TYPE: OnceCell<Py<PyAny>> = OnceCell::new();
&PY_TYPE
}}
}}
impl ToPyObject for {parent}{cons.name} {{
fn to_object(&self, py: Python) -> PyObject {{
let initializer = PyClassInitializer::from(AST)
.add_subclass({parent})
.add_subclass(Self);
Py::new(py, initializer).unwrap().into_py(py)
}}
}}
""",
            depth,
        )
class Pyo3PymoduleVisitor(EmitVisitor):
    """Emit the pymodule body: type-cell initialisation and class registration."""

    def __init__(self, namespace, *args, **kw):
        self.namespace = namespace
        super().__init__(*args, **kw)

    def visitModule(self, mod):
        for definition in mod.dfns:
            self.visit(definition)

    def visitType(self, type, depth=0):
        self.visit(type.value, type.name, depth)

    def visitProduct(self, product, name, depth=0):
        self.emit_fields(name, rust_type_name(name), False, depth)

    def visitSum(self, sum, name, depth):
        """Handle the sum itself (as simple), then each of its constructors."""
        rust_name = rust_type_name(name)
        simple = is_simple(sum)
        self.emit_fields(name, rust_name, True, depth)

        for cons in sum.types:
            self.visit(cons, name, simple, depth)

    def visitConstructor(self, cons, parent, simple, depth):
        combined_name = rust_type_name(parent) + rust_type_name(cons.name)
        self.emit_fields(cons.name, combined_name, simple, depth)

    def emit_fields(self, name, rust_name, simple, depth):
        """Emit the type-cell initialiser; for non-simple types also register the class.

        For simple types the looked-up Python attribute is called (`.call0()`)
        before being stored, and nothing else is emitted.
        """
        call = ".call0().unwrap()" if simple else ""

        self.emit(
            f"""
{rust_name}::py_type_cell().get_or_init(|| {{
ast_module.getattr("{name}").unwrap(){call}.into_py(py)
}});
""",
            depth,
        )
        if simple:
            return

        namespace = self.namespace
        self.emit(
            f"""
{{
m.add_class::<{rust_name}>()?;
let node = m.getattr("_{name}")?;
m.setattr("{name}", node)?;
let names: Vec<&'static str> = crate::{namespace}::{rust_name}::FIELD_NAMES.to_vec();
let fields = PyTuple::new(py, names);
node.setattr("_fields", fields)?;
}}
""",
            depth,
        )
class StdlibClassDefVisitor(EmitVisitor):
    """Emit the `Node*` pyclass definitions for the `_ast` module."""

    def visitModule(self, mod):
        for definition in mod.dfns:
            self.visit(definition)

    def visitType(self, type, depth=0):
        self.visit(type.value, type.name, depth)

    def visitSum(self, sum, name, depth):
        """Emit the empty base class of a sum, then a class per constructor."""
        struct_name = "Node" + rust_type_name(name)
        quoted_name = json.dumps(name)
        self.emit(
            f'#[pyclass(module = "_ast", name = {quoted_name}, base = "NodeAst")]',
            depth,
        )
        self.emit(f"struct {struct_name};", depth)
        self.emit("#[pyclass(flags(HAS_DICT, BASETYPE))]", depth)
        self.emit(f"impl {struct_name} {{}}", depth)
        for cons in sum.types:
            self.visit(cons, sum.attributes, struct_name, depth)

    def visitConstructor(self, cons, attrs, base, depth):
        self.gen_class_def(cons.name, cons.fields, attrs, depth, base)

    def visitProduct(self, product, name, depth):
        self.gen_class_def(name, product.fields, product.attributes, depth)

    def gen_class_def(self, name, fields, attrs, depth, base=None):
        """Emit a pyclass whose `_fields` and `_attributes` are set on class extension.

        Without ``base`` the class derives from `NodeAst` and is named
        `Node<Name>`; otherwise it is named `<base><Name>`.
        """
        type_suffix = rust_type_name(name)
        if base is None:
            base = "NodeAst"
            struct_name = "Node" + type_suffix
        else:
            struct_name = base + type_suffix

        self.emit(
            f'#[pyclass(module = "_ast", name = {json.dumps(name)}, base = {json.dumps(base)})]',
            depth,
        )
        self.emit(f"struct {struct_name};", depth)
        self.emit("#[pyclass(flags(HAS_DICT, BASETYPE))]", depth)
        self.emit(f"impl {struct_name} {{", depth)
        self.emit("#[extend_class]", depth + 1)
        self.emit(
            "fn extend_class_with_fields(ctx: &Context, class: &'static Py<PyType>) {",
            depth + 1,
        )

        field_exprs = ",".join(
            f"ctx.new_str(ascii!({json.dumps(field.name)})).into()" for field in fields
        )
        self.emit(
            f"class.set_attr(identifier!(ctx, _fields), ctx.new_tuple(vec![{field_exprs}]).into());",
            depth + 2,
        )

        attr_exprs = ",".join(
            f"ctx.new_str(ascii!({json.dumps(attr.name)})).into()" for attr in attrs
        )
        self.emit(
            f"class.set_attr(identifier!(ctx, _attributes), ctx.new_list(vec![{attr_exprs}]).into());",
            depth + 2,
        )

        self.emit("}", depth + 1)
        self.emit("}", depth)
class StdlibExtendModuleVisitor(EmitVisitor):
    """Emit `extend_module_nodes`, which registers every `Node*` class."""

    def visitModule(self, mod):
        top_level = 0
        self.emit(
            "pub fn extend_module_nodes(vm: &VirtualMachine, module: &Py<PyModule>) {",
            top_level,
        )
        self.emit("extend_module!(vm, module, {", top_level + 1)
        for definition in mod.dfns:
            self.visit(definition, top_level + 2)
        self.emit("})", top_level + 1)
        self.emit("}", top_level)

    def visitType(self, type, depth):
        self.visit(type.value, type.name, depth)

    def visitSum(self, sum, name, depth):
        """Register the sum's base class, then each constructor class."""
        sum_rust_name = rust_type_name(name)
        self.emit(
            f"{json.dumps(name)} => Node{sum_rust_name}::make_class(&vm.ctx),", depth
        )
        for cons in sum.types:
            self.visit(cons, depth, sum_rust_name)

    def visitConstructor(self, cons, depth, rust_name):
        self.gen_extension(cons.name, depth, rust_name)

    def visitProduct(self, product, name, depth):
        self.gen_extension(name, depth)

    def gen_extension(self, name, depth, base=""):
        """Emit one `"name" => Node<base><Name>::make_class(...)` entry."""
        class_name = f"Node{base}{rust_type_name(name)}"
        self.emit(f"{json.dumps(name)} => {class_name}::make_class(&vm.ctx),", depth)
class StdlibTraitImplVisitor(EmitVisitor):
    """Emit ``NamedNode`` and ``Node`` trait impls for the ``ast::located`` types.

    For every ASDL definition the visitor writes Rust source through
    ``self.emit``: a ``NamedNode`` impl carrying the ASDL name, and a ``Node``
    impl with ``ast_to_object`` / ``ast_from_object``.

    NOTE(review): ``depth`` is handed to ``self.emit`` as-is.  ``emit`` is
    defined on the base class (not shown here) and is assumed to use it as
    the indentation level of the emitted line -- confirm.
    """

    def visitModule(self, mod):
        """Visit every definition of the module."""
        for dfn in mod.dfns:
            self.visit(dfn)

    def visitType(self, type, depth=0):
        """Dispatch on the definition's value, passing its name along."""
        self.visit(type.value, type.name, depth)

    def _emit_named_node(self, rust_path, name, depth):
        """Emit ``impl NamedNode for ast::located::<rust_path>`` with ``NAME``."""
        self.emit(f"impl NamedNode for ast::located::{rust_path} {{", depth)
        self.emit(f"const NAME: &'static str = {json.dumps(name)};", depth + 1)
        self.emit("}", depth)

    def visitSum(self, sum, name, depth):
        """Emit the trait impls for a sum type, then for its constructors.

        A simple sum (no constructor has fields) maps each variant directly to
        its node class; otherwise each variant delegates to the impl of its
        payload struct, which is generated by visiting the constructors.
        """
        rust_name = rust_type_name(name)
        simple = is_simple(sum)

        self._emit_named_node(rust_name, name, depth)
        self.emit("// sum", depth)
        self.emit(f"impl Node for ast::located::{rust_name} {{", depth)
        self.emit(
            "fn ast_to_object(self, vm: &VirtualMachine) -> PyObjectRef {", depth + 1
        )
        if simple:
            self.emit("let node_type = match self {", depth + 2)
            for cons in sum.types:
                # Match arms sit one level inside the ``match``, exactly like
                # the arms of the non-simple branch below.
                self.emit(
                    f"ast::located::{rust_name}::{cons.name} => Node{rust_name}{cons.name}::static_type(),",
                    depth + 3,
                )
            self.emit("};", depth + 2)
            self.emit(
                "NodeAst.into_ref_with_type(vm, node_type.to_owned()).unwrap().into()",
                depth + 2,
            )
        else:
            self.emit("match self {", depth + 2)
            for cons in sum.types:
                self.emit(
                    f"ast::located::{rust_name}::{cons.name}(cons) => cons.ast_to_object(vm),",
                    depth + 3,
                )
            self.emit("}", depth + 2)
        self.emit("}", depth + 1)
        self.emit(
            "fn ast_from_object(_vm: &VirtualMachine, _object: PyObjectRef) -> PyResult<Self> {",
            depth + 1,
        )
        self.gen_sum_from_object(sum, name, rust_name, depth + 2)
        self.emit("}", depth + 1)
        self.emit("}", depth)

        if not simple:
            for cons in sum.types:
                self.visit(cons, sum, rust_name, depth)

    def visitConstructor(self, cons, sum, sum_rust_name, depth):
        """Emit the trait impls for the payload struct of one constructor."""
        rust_path = f"{sum_rust_name}{rust_type_name(cons.name)}"

        self.emit("// constructor", depth)
        self._emit_named_node(rust_path, cons.name, depth)
        self.emit(f"impl Node for ast::located::{rust_path} {{", depth)

        fields_pattern = self.make_pattern(cons.fields)
        self.emit(
            "fn ast_to_object(self, _vm: &VirtualMachine) -> PyObjectRef {", depth + 1
        )
        # Function-body statement: same level as the lines ``make_node`` emits.
        self.emit(
            f"let ast::located::{rust_path} {{ {fields_pattern} }} = self;",
            depth + 2,
        )
        self.make_node(cons.name, sum, cons.fields, depth + 2, sum_rust_name)
        self.emit("}", depth + 1)
        self.emit(
            "fn ast_from_object(_vm: &VirtualMachine, _object: PyObjectRef) -> PyResult<Self> {",
            depth + 1,
        )
        self.gen_product_from_object(
            cons, cons.name, rust_path, sum.attributes, depth + 2
        )
        self.emit("}", depth + 1)
        # Closes the ``impl Node`` block opened at ``depth``.
        self.emit("}", depth)

    def visitProduct(self, product, name, depth):
        """Emit the trait impls for a product type."""
        struct_name = rust_type_name(name)

        self.emit("// product", depth)
        self._emit_named_node(struct_name, name, depth)
        self.emit(f"impl Node for ast::located::{struct_name} {{", depth)
        self.emit(
            "fn ast_to_object(self, _vm: &VirtualMachine) -> PyObjectRef {", depth + 1
        )
        fields_pattern = self.make_pattern(product.fields)
        self.emit(
            f"let ast::located::{struct_name} {{ {fields_pattern} }} = self;",
            depth + 2,
        )
        self.make_node(name, product, product.fields, depth + 2)
        self.emit("}", depth + 1)
        self.emit(
            "fn ast_from_object(_vm: &VirtualMachine, _object: PyObjectRef) -> PyResult<Self> {",
            depth + 1,
        )
        self.gen_product_from_object(
            product, name, struct_name, product.attributes, depth + 2
        )
        self.emit("}", depth + 1)
        self.emit("}", depth)

    def make_node(self, variant, owner, fields, depth, base=""):
        """Emit the body of ``ast_to_object`` that builds the Python node.

        The node is created with the ``Node<base><RustVariant>`` type; each
        field is stored into the node's dict, and the location is added when
        ``owner`` has attributes.
        """
        rust_variant = rust_type_name(variant)
        self.emit(
            f"let node = NodeAst.into_ref_with_type(_vm, Node{base}{rust_variant}::static_type().to_owned()).unwrap();",
            depth,
        )
        if fields or owner.attributes:
            self.emit("let dict = node.as_object().dict().unwrap();", depth)
        for f in fields:
            self.emit(
                f"dict.set_item({json.dumps(f.name)}, {rust_field(f.name)}.ast_to_object(_vm), _vm).unwrap();",
                depth,
            )
        if owner.attributes:
            self.emit("node_add_location(&dict, _range, _vm);", depth)
        self.emit("node.into()", depth)

    def make_pattern(self, fields):
        """Return the struct-destructuring pattern: every field, then ``range``."""
        return "".join(f"{rust_field(f.name)}," for f in fields) + "range: _range"

    def gen_sum_from_object(self, sum, sum_name, rust_name, depth):
        """Emit the body of ``ast_from_object`` for a sum type.

        The generated code is an ``if``/``else`` chain over the object's class,
        ending in a type error when no constructor class matches.
        """
        # NOTE: location extraction (``extract_location``) is intentionally
        # not emitted here.
        has_payload = not is_simple(sum)
        self.emit("let _cls = _object.class();", depth)
        self.emit("Ok(", depth)
        for cons in sum.types:
            self.emit(
                f"if _cls.is(Node{rust_name}{cons.name}::static_type()) {{", depth
            )
            self.emit(f"ast::located::{rust_name}::{cons.name}", depth + 1)
            if has_payload:
                self.emit(
                    f"(ast::located::{rust_name}{cons.name}::ast_from_object(_vm, _object)?)",
                    depth + 1,
                )
            self.emit("} else", depth)
        self.emit("{", depth)
        msg = f'format!("expected some sort of {sum_name}, but got {{}}",_object.repr(_vm)?)'
        self.emit(f"return Err(_vm.new_type_error({msg}));", depth + 1)
        self.emit("})", depth)

    def gen_product_from_object(
        self, product, product_name, struct_name, has_attributes, depth
    ):
        """Emit ``Ok(<struct literal>)`` as the body of ``ast_from_object``.

        ``has_attributes`` is only tested for truthiness; callers in this
        class pass the attribute list itself.
        """
        self.emit("Ok(", depth)
        self.gen_construction(
            struct_name, product, product_name, has_attributes, depth + 1
        )
        self.emit(")", depth)

    def gen_construction_fields(self, cons, name, depth):
        """Emit one ``field: <decode expr>,`` line per field at ``depth + 1``."""
        for field in cons.fields:
            self.emit(
                f"{rust_field(field.name)}: {self.decode_field(field, name)},",
                depth + 1,
            )

    def gen_construction(self, cons_path, cons, name, attributes, depth):
        """Emit the struct literal ``ast::located::<cons_path> { ... }``.

        ``range`` is read from the object when ``attributes`` is truthy and
        defaulted otherwise.
        """
        self.emit(f"ast::located::{cons_path} {{", depth)
        # gen_construction_fields adds one level itself; passing ``depth``
        # keeps the fields at the same level as the ``range`` entry below.
        self.gen_construction_fields(cons, name, depth)
        if attributes:
            self.emit(f'range: range_from_object(_vm, _object, "{name}")?,', depth + 1)
        else:
            self.emit("range: Default::default(),", depth + 1)
        self.emit("}", depth)

    def extract_location(self, typename, depth):
        """Emit a ``let _location = {...};`` block built from lineno/col_offset.

        Not called from within this class at present.
        """
        row = self.decode_field(asdl.Field("int", "lineno"), typename)
        column = self.decode_field(asdl.Field("int", "col_offset"), typename)
        self.emit(
            "\n"
            "let _location = {\n"
            f"let row = {row};\n"
            f"let column = {column};\n"
            "try_location(row, column)\n"
            "};",
            depth,
        )

    def decode_field(self, field, typename):
        """Return the Rust expression that reads ``field`` from ``_object``.

        Optional, non-sequence fields use ``get_node_field_opt``; everything
        else uses ``get_node_field`` with ``typename`` passed along.
        """
        name = json.dumps(field.name)
        if field.opt and not field.seq:
            return f"get_node_field_opt(_vm, &_object, {name})?.map(|obj| Node::ast_from_object(_vm, obj)).transpose()?"
        return f"Node::ast_from_object(_vm, get_node_field(_vm, &_object, {name}, {json.dumps(typename)})?)?"
class ChainOfVisitors:
    """Apply a fixed sequence of visitors to the same object, in order."""

    def __init__(self, *visitors):
        # Kept as the tuple of positional arguments, in call order.
        self.visitors = visitors

    def visit(self, object):
        """Run each visitor on ``object``, then call its ``emit("", 0)``."""
        for visitor in self.visitors:
            visitor.visit(object)
            visitor.emit("", 0)
def write_ast_def(mod, type_info, f):
    """Write the ``TextRange`` import to ``f``, then run ``StructVisitor`` on ``mod``.

    ``main`` uses this writer for ``generic.rs``.
    """
    prelude = "use crate::text_size::TextRange;"
    f.write(prelude)
    visitor = StructVisitor(f, type_info)
    visitor.visit(mod)
def write_fold_def(mod, type_info, f):
    """Run ``FoldModuleVisitor`` (built from ``f`` and ``type_info``) on ``mod``.

    ``main`` uses this writer for ``fold.rs``.
    """
    visitor = FoldModuleVisitor(f, type_info)
    visitor.visit(mod)
def write_visitor_def(mod, type_info, f):
    """Run ``VisitorModuleVisitor`` (built from ``f`` and ``type_info``) on ``mod``.

    ``main`` uses this writer for ``visitor.rs``.
    """
    visitor = VisitorModuleVisitor(f, type_info)
    visitor.visit(mod)
def write_ranged_def(mod, type_info, f):
    """Run ``RangedDefVisitor`` (built from ``f`` and ``type_info``) on ``mod``.

    ``main`` uses this writer for ``ranged.rs``.
    """
    visitor = RangedDefVisitor(f, type_info)
    visitor.visit(mod)
def write_located_def(mod, type_info, f):
    """Run ``LocatedDefVisitor`` (built from ``f`` and ``type_info``) on ``mod``.

    ``main`` uses this writer for ``located.rs``.
    """
    visitor = LocatedDefVisitor(f, type_info)
    visitor.visit(mod)
def write_ast_pyo3(mod, type_info, namespace, f):
    """Run ``ToPyo3AstVisitor`` for ``namespace`` on ``mod``.

    ``main`` calls this with the ``"located"`` and ``"ranged"`` namespaces to
    produce ``to_pyo3_located.rs`` and ``to_pyo3_ranged.rs``.
    """
    visitor = ToPyo3AstVisitor(namespace, f, type_info)
    visitor.visit(mod)
def write_pyo3_def(mod, type_info, namespace, borrow, f):
    """Write the pyo3 struct definitions followed by ``add_to_module``.

    Three pieces are written in order: the output of ``Pyo3StructVisitor``,
    the opening of the Rust ``add_to_module`` function, and the output of
    ``Pyo3PymoduleVisitor``; the function is then closed with ``Ok(())``.
    """
    struct_visitor = Pyo3StructVisitor(namespace, f, type_info, borrow=borrow)
    struct_visitor.visit(mod)

    # Opening of ``add_to_module``; the closing brace is written at the end.
    module_prologue = (
        "\n"
        "use once_cell::sync::OnceCell;\n"
        "pub fn add_to_module(py: Python, m: &PyModule) -> PyResult<()> {\n"
        "super::init_module(py, m)?;\n"
        'let ast_module = PyModule::import(py, "_ast")?;\n'
    )
    f.write(module_prologue)

    module_visitor = Pyo3PymoduleVisitor(namespace, f, type_info)
    module_visitor.visit(mod)

    f.write("Ok(())\n}")
def write_ast_mod(mod, type_info, f):
    """Write the stdlib module file: a header, then three visitors' output.

    The class definitions, the trait impls and the module-extension function
    are generated in that order through a ``ChainOfVisitors``.
    """
    header = textwrap.dedent(
        """
        #![allow(clippy::all)]
        use super::*;
        use crate::common::ascii;
        """
    )
    f.write(header)

    visitor_classes = (
        StdlibClassDefVisitor,
        StdlibTraitImplVisitor,
        StdlibExtendModuleVisitor,
    )
    chain = ChainOfVisitors(*(cls(f, type_info) for cls in visitor_classes))
    chain.visit(mod)
def main(
    input_filename,
    ast_dir,
    module_filename,
    dump_module=False,
):
    """Parse the ASDL file and regenerate every Rust output file.

    input_filename  -- path of the ASDL description, passed to ``asdl.parse``.
    ast_dir         -- directory (path object) receiving one ``<name>.rs``
                       file per writer.
    module_filename -- path object of the stdlib module file to write.
    dump_module     -- when true, print the parsed module first.

    Exits with status 1 when ``asdl.check`` rejects the parsed module.
    """
    from functools import partial

    auto_gen_msg = AUTO_GEN_MESSAGE.format("/".join(Path(__file__).parts[-2:]))

    mod = asdl.parse(input_filename)
    if dump_module:
        print("Parsed Module:")
        print(mod)
    if not asdl.check(mod):
        sys.exit(1)

    type_info = {}
    FindUserDataTypesVisitor(type_info).visit(mod)

    # Output file stem -> writer taking only the open file.  Insertion order
    # is the order in which the files are written.
    writers = {
        "generic": partial(write_ast_def, mod, type_info),
        "fold": partial(write_fold_def, mod, type_info),
        "ranged": partial(write_ranged_def, mod, type_info),
        "located": partial(write_located_def, mod, type_info),
        "visitor": partial(write_visitor_def, mod, type_info),
        "to_pyo3_located": partial(write_ast_pyo3, mod, type_info, "located"),
        "to_pyo3_ranged": partial(write_ast_pyo3, mod, type_info, "ranged"),
        # "pyo3_located" (write_pyo3_def with "located") is currently disabled.
        "pyo3_ranged": partial(write_pyo3_def, mod, type_info, "ranged", True),
    }
    for stem, write in writers.items():
        target = ast_dir / f"{stem}.rs"
        with target.open("w") as out:
            out.write(auto_gen_msg)
            write(out)

    with module_filename.open("w") as module_file:
        module_file.write(auto_gen_msg)
        write_ast_mod(mod, type_info, module_file)

    print(f"{ast_dir}, {module_filename} regenerated.")
if __name__ == "__main__":
    # Command-line entry point: parse the arguments and hand them to main().
    cli = ArgumentParser()
    cli.add_argument("input_file", type=Path)
    cli.add_argument("-A", "--ast-dir", type=Path, required=True)
    cli.add_argument("-M", "--module-file", type=Path, required=True)
    cli.add_argument("-d", "--dump-module", action="store_true")
    options = cli.parse_args()
    main(
        input_filename=options.input_file,
        ast_dir=options.ast_dir,
        module_filename=options.module_file,
        dump_module=options.dump_module,
    )