Move range from Attributed to Nodes (#22)

* Move `range` from `Attributed` to `Node`s

* No Attributed + custom for Range PoC

* Generate all located variants, generate enum implementations

* Implement `Copy` on simple enums

* Move `Suite` to `ranged` and `located`

* Update tests
---------

Co-authored-by: Jeong YunWon <jeong@youknowone.org>
This commit is contained in:
Micha Reiser 2023-05-15 08:08:12 +02:00 committed by GitHub
parent a983f4383f
commit 192379cede
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
126 changed files with 29410 additions and 30670 deletions

View file

@ -1,6 +1,6 @@
# spell-checker:words dfn dfns
#! /usr/bin/env python
# ! /usr/bin/env python
"""Generate Rust code from an ASDL description."""
import sys
@ -84,6 +84,7 @@ class TypeInfo:
enum_name: Optional[str]
has_user_data: Optional[bool]
has_attributes: bool
is_simple: bool
empty_field: bool
children: set
boxed: bool
@ -95,6 +96,7 @@ class TypeInfo:
self.enum_name = None
self.has_user_data = None
self.has_attributes = False
self.is_simple = False
self.empty_field = False
self.children = set()
self.boxed = False
@ -104,6 +106,14 @@ class TypeInfo:
def __repr__(self):
return f"<TypeInfo: {self.name}>"
def no_cfg(self, typeinfo):
if self.product:
return self.has_attributes
elif self.enum_name:
return typeinfo[self.enum_name].has_attributes
else:
return self.has_attributes
@property
def rust_name(self):
return rust_type_name(self.name)
@ -124,19 +134,6 @@ class TypeInfo:
name = rust_type_name(self.enum_name) + rust_name
return name
@property
def rust_suffix(self):
if self.product:
if self.has_attributes:
return "Data"
else:
return ""
else:
if self.has_attributes:
return "Kind"
else:
return ""
def determine_user_data(self, type_info, stack):
if self.name in stack:
return None
@ -160,7 +157,8 @@ class TypeInfoMixin:
return self.type_info[typ].has_user_data
def apply_generics(self, typ, *generics):
if self.has_user_data(typ):
needs_generics = not self.type_info[typ].is_simple
if needs_generics:
return [f"<{g}>" for g in generics]
else:
return ["" for g in generics]
@ -208,6 +206,7 @@ class FindUserDataTypesVisitor(asdl.VisitorBase):
info = self.type_info[name]
if is_simple(sum):
info.has_user_data = False
info.is_simple = True
else:
for t in sum.types:
t_info = TypeInfo(t.name)
@ -218,7 +217,7 @@ class FindUserDataTypesVisitor(asdl.VisitorBase):
if len(sum.types) > 1:
info.boxed = True
if sum.attributes:
# attributes means located, which has the `custom: U` field
# attributes means located, which has the `range: R` field
info.has_user_data = True
info.has_attributes = True
@ -228,7 +227,7 @@ class FindUserDataTypesVisitor(asdl.VisitorBase):
def visitProduct(self, product, name):
info = self.type_info[name]
if product.attributes:
# attributes means located, which has the `custom: U` field
# attributes means located, which has the `range: R` field
info.has_user_data = True
info.has_attributes = True
info.has_expr = product_has_expr(product)
@ -277,10 +276,16 @@ class StructVisitor(EmitVisitor):
def emit_attrs(self, depth):
self.emit("#[derive(Clone, Debug, PartialEq)]", depth)
def emit_range(self, has_attributes, depth):
if has_attributes:
self.emit("pub range: R,", depth + 1)
else:
self.emit("pub range: crate::ranged::OptionalRange<R>,", depth + 1)
def simple_sum(self, sum, name, depth):
rust_name = rust_type_name(name)
self.emit_attrs(depth)
self.emit("#[derive(is_macro::Is)]", depth)
self.emit("#[derive(is_macro::Is, Copy, Hash, Eq)]", depth)
self.emit(f"pub enum {rust_name} {{", depth)
for variant in sum.types:
self.emit(f"{variant.name},", depth + 1)
@ -289,20 +294,16 @@ class StructVisitor(EmitVisitor):
def sum_with_constructors(self, sum, name, depth):
type_info = self.type_info[name]
suffix = type_info.rust_suffix
rust_name = rust_type_name(name)
# all the attributes right now are for location, so if it has attrs we
# can just wrap it in Attributed<>
for t in sum.types:
if not t.fields:
continue
self.sum_subtype_struct(type_info, t, rust_name, depth)
generics, generics_applied = self.apply_generics(name, "U = ()", "U")
self.emit_attrs(depth)
self.emit("#[derive(is_macro::Is)]", depth)
self.emit(f"pub enum {rust_name}{suffix}{generics} {{", depth)
self.emit(f"pub enum {rust_name}<R = TextRange> {{", depth)
needs_escape = any(rust_field_name(t.name) in RUST_KEYWORDS for t in sum.types)
for t in sum.types:
if needs_escape:
@ -310,35 +311,29 @@ class StructVisitor(EmitVisitor):
f'#[is(name = "{rust_field_name(t.name)}_{rust_name.lower()}")]',
depth + 1,
)
if t.fields:
(t_generics_applied,) = self.apply_generics(t.name, "U")
self.emit(
f"{t.name}({rust_name}{t.name}{t_generics_applied}),", depth + 1
)
else:
self.emit(f"{t.name},", depth + 1)
self.emit(f"{t.name}({rust_name}{t.name}<R>),", depth + 1)
self.emit("}", depth)
if type_info.has_attributes:
self.emit(
f"pub type {rust_name}<U = ()> = Attributed<{rust_name}{suffix}{generics_applied}, U>;",
depth,
)
self.emit("", depth)
def sum_subtype_struct(self, sum_type_info, t, rust_name, depth):
self.emit_attrs(depth)
generics, generics_applied = self.apply_generics(t.name, "U = ()", "U")
payload_name = f"{rust_name}{t.name}"
self.emit(f"pub struct {payload_name}{generics} {{", depth)
self.emit(f"pub struct {payload_name}<R = TextRange> {{", depth)
self.emit_range(sum_type_info.has_attributes, depth)
for f in t.fields:
self.visit(f, sum_type_info, "pub ", depth + 1, t.name)
assert sum_type_info.has_attributes == self.type_info[t.name].no_cfg(
self.type_info
)
self.emit("}", depth)
self.emit(
textwrap.dedent(
f"""
impl{generics_applied} From<{payload_name}{generics_applied}> for {rust_name}{sum_type_info.rust_suffix}{generics_applied} {{
fn from(payload: {payload_name}{generics_applied}) -> Self {{
{rust_name}{sum_type_info.rust_suffix}::{t.name}(payload)
impl<R> From<{payload_name}<R>> for {rust_name}<R> {{
fn from(payload: {payload_name}<R>) -> Self {{
{rust_name}::{t.name}(payload)
}}
}}
"""
@ -346,6 +341,8 @@ class StructVisitor(EmitVisitor):
depth,
)
self.emit("", depth)
def visitConstructor(self, cons, parent, depth):
if cons.fields:
self.emit(f"{cons.name} {{", depth)
@ -358,21 +355,21 @@ class StructVisitor(EmitVisitor):
def visitField(self, field, parent, vis, depth, constructor=None):
typ = rust_type_name(field.type)
field_type = self.type_info.get(field.type)
if field_type and field_type.has_user_data:
typ = f"{typ}<U>"
if field_type and not field_type.is_simple:
typ = f"{typ}<R>"
# don't box if we're doing Vec<T>, but do box if we're doing Vec<Option<Box<T>>>
if (
field_type
and field_type.boxed
and (not (parent.product or field.seq) or field.opt)
field_type
and field_type.boxed
and (not (parent.product or field.seq) or field.opt)
):
typ = f"Box<{typ}>"
if field.opt or (
# When a dictionary literal contains dictionary unpacking (e.g., `{**d}`),
# the expression to be unpacked goes in `values` with a `None` at the corresponding
# position in `keys`. To handle this, the type of `keys` needs to be `Option<Vec<T>>`.
constructor == "Dict"
and field.name == "keys"
# When a dictionary literal contains dictionary unpacking (e.g., `{**d}`),
# the expression to be unpacked goes in `values` with a `None` at the corresponding
# position in `keys`. To handle this, the type of `keys` needs to be `Option<Vec<T>>`.
constructor == "Dict"
and field.name == "keys"
):
typ = f"Option<{typ}>"
if field.seq:
@ -384,28 +381,15 @@ class StructVisitor(EmitVisitor):
def visitProduct(self, product, name, depth):
type_info = self.type_info[name]
generics, generics_applied = self.apply_generics(name, "U = ()", "U")
data_name = rust_name = rust_type_name(name)
if product.attributes:
data_name = rust_name + "Data"
product_name = rust_type_name(name)
self.emit_attrs(depth)
has_expr = product_has_expr(product)
if has_expr:
data_def = f"{data_name}{generics}"
else:
data_def = data_name
self.emit(f"pub struct {data_def} {{", depth)
self.emit(f"pub struct {product_name}<R = TextRange> {{", depth)
for f in product.fields:
self.visit(f, type_info, "pub ", depth + 1)
assert bool(product.attributes) == type_info.no_cfg(self.type_info)
self.emit_range(product.attributes, depth + 1)
self.emit("}", depth)
if product.attributes:
# attributes should just be location info
if not has_expr:
generics_applied = ""
self.emit(
f"pub type {rust_name}<U = ()> = Attributed<{data_name}{generics_applied}, U>;",
depth,
)
self.emit("", depth)
@ -414,16 +398,18 @@ class FoldTraitDefVisitor(EmitVisitor):
self.emit("pub trait Fold<U> {", depth)
self.emit("type TargetU;", depth + 1)
self.emit("type Error;", depth + 1)
self.emit(
"fn map_user(&mut self, user: U) -> Result<Self::TargetU, Self::Error>;",
depth + 1,
)
self.emit(
"""
fn map_attributed<T>(&mut self, attributed: Attributed<T, U>) -> Result<Attributed<T, Self::TargetU>, Self::Error> {
let custom = self.map_user(attributed.custom)?;
Ok(Attributed { range: attributed.range, custom, node: attributed.node })
}""",
fn map_user(&mut self, user: U) -> Result<Self::TargetU, Self::Error>;
#[cfg(feature = "all-nodes-with-ranges")]
fn map_user_cfg(&mut self, user: U) -> Result<Self::TargetU, Self::Error> {
self.map_user(user)
}
#[cfg(not(feature = "all-nodes-with-ranges"))]
fn map_user_cfg(&mut self, _user: crate::EmptyRange<U>) -> Result<crate::EmptyRange<Self::TargetU>, Self::Error> {
Ok(crate::EmptyRange::default())
}
""",
depth + 1,
)
self.emit(
@ -451,15 +437,6 @@ class FoldTraitDefVisitor(EmitVisitor):
class FoldImplVisitor(EmitVisitor):
def visitModule(self, mod, depth):
self.emit(
"fn fold_attributed<U, F: Fold<U> + ?Sized, T, MT>(folder: &mut F, node: Attributed<T, U>, f: impl FnOnce(&mut F, T) -> Result<MT, F::Error>) -> Result<Attributed<MT, F::TargetU>, F::Error> {",
depth,
)
self.emit(
"let node = folder.map_attributed(node)?; Ok(Attributed { custom: node.custom, range: node.range, node: f(folder, node.node)? })",
depth + 1,
)
self.emit("}", depth)
for dfn in mod.dfns:
self.visit(dfn, depth)
@ -472,6 +449,7 @@ class FoldImplVisitor(EmitVisitor):
name, "T", "U", "F::TargetU"
)
enum_name = rust_type_name(name)
simple = is_simple(sum)
self.emit(f"impl<T, U> Foldable<T, U> for {enum_name}{apply_t} {{", depth)
self.emit(f"type Mapped = {enum_name}{apply_u};", depth + 1)
@ -487,25 +465,29 @@ class FoldImplVisitor(EmitVisitor):
f"pub fn fold_{name}<U, F: Fold<U> + ?Sized>(#[allow(unused)] folder: &mut F, node: {enum_name}{apply_u}) -> Result<{enum_name}{apply_target_u}, F::Error> {{",
depth,
)
if type_info.has_attributes:
self.emit("fold_attributed(folder, node, |folder, node| {", depth)
if simple:
self.emit("Ok(node) }", depth + 1)
return
self.emit("match node {", depth + 1)
for cons in sum.types:
fields_pattern = self.make_pattern(
enum_name, type_info.rust_suffix, cons.name, cons.fields
)
fields_pattern = self.make_pattern(enum_name, cons.name, cons.fields)
self.emit(
f"{fields_pattern[0]} {{ {fields_pattern[1]} }} {fields_pattern[2]} => {{",
f"{fields_pattern[0]} {{ {fields_pattern[1]}}} {fields_pattern[2]} => {{",
depth + 2,
)
map_user_suffix = "" if type_info.has_attributes else "_cfg"
self.emit(
f"let range = folder.map_user{map_user_suffix}(range)?;", depth + 3
)
self.gen_construction(
fields_pattern[0], cons.fields, fields_pattern[2], depth + 3
)
self.emit("}", depth + 2)
self.emit("}", depth + 1)
if type_info.has_attributes:
self.emit("})", depth)
self.emit("}", depth)
def visitProduct(self, product, name, depth):
@ -529,27 +511,26 @@ class FoldImplVisitor(EmitVisitor):
f"pub fn fold_{name}<U, F: Fold<U> + ?Sized>(#[allow(unused)] folder: &mut F, node: {struct_name}{apply_u}) -> Result<{struct_name}{apply_target_u}, F::Error> {{",
depth,
)
if has_attributes:
self.emit("fold_attributed(folder, node, |folder, node| {", depth)
rust_name = struct_name + "Data"
else:
rust_name = struct_name
fields_pattern = self.make_pattern(rust_name, struct_name, None, product.fields)
self.emit(f"let {rust_name} {{ {fields_pattern[1]} }} = node;", depth + 1)
self.gen_construction(rust_name, product.fields, "", depth + 1)
if has_attributes:
self.emit("})", depth)
fields_pattern = self.make_pattern(struct_name, struct_name, product.fields)
self.emit(f"let {struct_name} {{ {fields_pattern[1]} }} = node;", depth + 1)
map_user_suffix = "" if has_attributes else "_cfg"
self.emit(f"let range = folder.map_user{map_user_suffix}(range)?;", depth + 3)
self.gen_construction(struct_name, product.fields, "", depth + 1)
self.emit("}", depth)
def make_pattern(self, rust_name, suffix, fieldname, fields):
if fields:
header = f"{rust_name}{suffix}::{fieldname}({rust_name}{fieldname}"
footer = ")"
else:
header = f"{rust_name}{suffix}::{fieldname}"
footer = ""
def make_pattern(self, rust_name, fieldname: str, fields):
header = f"{rust_name}::{fieldname}({rust_name}{fieldname}"
footer = ")"
body = ",".join(rust_field(f.name) for f in fields)
if body:
body += ","
body += "range"
return header, body, footer
def gen_construction(self, header, fields, footer, depth):
@ -557,6 +538,8 @@ class FoldImplVisitor(EmitVisitor):
for field in fields:
name = rust_field(field.name)
self.emit(f"{name}: Foldable::fold({name}, folder)?,", depth + 1)
self.emit("range,", depth + 1)
self.emit(f"}}{footer})", depth)
@ -584,7 +567,7 @@ class VisitorTraitDefVisitor(StructVisitor):
return rust_type_name(name)
def visitModule(self, mod, depth):
self.emit("pub trait Visitor<U=()> {", depth)
self.emit("pub trait Visitor<R=crate::text_size::TextRange> {", depth)
for dfn in mod.dfns:
self.visit(dfn, depth + 1)
@ -595,16 +578,14 @@ class VisitorTraitDefVisitor(StructVisitor):
def emit_visitor(self, nodename, depth, has_node=True):
type_info = self.type_info[nodename]
if has_node:
node_type = type_info.rust_sum_name
node_value = "node"
else:
node_type = "()"
node_value = "()"
node_type = type_info.rust_sum_name
generic, = self.apply_generics(nodename, "R")
self.emit(
f"fn visit_{type_info.sum_name}(&mut self, node: {node_type}) {{", depth
f"fn visit_{type_info.sum_name}(&mut self, node: {node_type}{generic}) {{", depth
)
self.emit(f"self.generic_visit_{type_info.sum_name}({node_value})", depth + 1)
if has_node:
self.emit(f"self.generic_visit_{type_info.sum_name}(node)", depth + 1)
self.emit("}", depth)
def emit_generic_visitor_signature(self, nodename, depth, has_node=True):
@ -613,8 +594,9 @@ class VisitorTraitDefVisitor(StructVisitor):
node_type = type_info.rust_sum_name
else:
node_type = "()"
generic, = self.apply_generics(nodename, "R")
self.emit(
f"fn generic_visit_{type_info.sum_name}(&mut self, node: {node_type}) {{",
f"fn generic_visit_{type_info.sum_name}(&mut self, node: {node_type}{generic}) {{",
depth,
)
@ -628,16 +610,15 @@ class VisitorTraitDefVisitor(StructVisitor):
def visit_match_for_type(self, nodename, rust_name, type_, depth):
self.emit(f"{rust_name}::{type_.name}", depth)
if type_.fields:
self.emit("(data)", depth)
data = "data"
else:
data = "()"
self.emit(f"=> self.visit_{nodename}_{type_.name}({data}),", depth)
self.emit("(data)", depth)
self.emit(f"=> self.visit_{nodename}_{type_.name}(data),", depth)
def visit_sum_type(self, name, type_, depth):
self.emit_visitor(type_.name, depth, has_node=type_.fields)
self.emit_generic_visitor_signature(type_.name, depth, has_node=type_.fields)
if not type_.fields:
return
self.emit_generic_visitor_signature(type_.name, depth, has_node=True)
for f in type_.fields:
fieldname = rust_field(f.name)
field_type = self.type_info.get(f.type)
@ -669,15 +650,13 @@ class VisitorTraitDefVisitor(StructVisitor):
if not sum.attributes:
return
rust_name = enum_name = rust_type_name(name)
if sum.attributes:
rust_name = enum_name + "Kind"
enum_name = rust_type_name(name)
self.emit_visitor(name, depth)
self.emit_generic_visitor_signature(name, depth)
depth += 1
self.emit("match node.node {", depth)
self.emit("match node {", depth)
for t in sum.types:
self.visit_match_for_type(name, rust_name, t, depth + 1)
self.visit_match_for_type(name, enum_name, t, depth + 1)
self.emit("}", depth)
depth -= 1
self.emit("}", depth)
@ -838,8 +817,6 @@ class TraitImplVisitor(EmitVisitor):
def visitProduct(self, product, name, depth):
struct_name = rust_type_name(name)
if product.attributes:
struct_name += "Data"
self.emit(f"impl NamedNode for ast::located::{struct_name} {{", depth)
self.emit(f"const NAME: &'static str = {json.dumps(name)};", depth + 1)
@ -945,6 +922,124 @@ class TraitImplVisitor(EmitVisitor):
return f"Node::ast_from_object(_vm, get_node_field(_vm, &_object, {name}, {json.dumps(typename)})?)?"
class RangedDefVisitor(EmitVisitor):
def visitModule(self, mod):
for dfn in mod.dfns:
self.visit(dfn)
def visitType(self, type, depth=0):
self.visit(type.value, type.name, depth)
def visitSum(self, sum, name, depth):
info = self.type_info[name]
if info.is_simple:
return
sum_match_arms = ""
for ty in sum.types:
variant_info = self.type_info[ty.name]
sum_match_arms += f" Self::{variant_info.rust_name}(node) => node.range(),"
self.emit_ranged_impl(variant_info)
if not info.no_cfg(self.type_info):
self.emit('#[cfg(feature = "all-nodes-with-ranges")]', 0)
self.emit(f"""
impl Ranged for crate::{info.rust_sum_name} {{
fn range(&self) -> TextRange {{
match self {{
{sum_match_arms}
}}
}}
}}
""".lstrip(), 0)
def visitProduct(self, product, name, depth):
info = self.type_info[name]
self.emit_ranged_impl(info)
def emit_ranged_impl(self, info):
if not info.no_cfg(self.type_info):
self.emit('#[cfg(feature = "all-nodes-with-ranges")]', 0)
self.file.write(
f"""
impl Ranged for crate::generic::{info.rust_sum_name}::<TextRange> {{
fn range(&self) -> TextRange {{
self.range
}}
}}
""".strip()
)
class LocatedDefVisitor(EmitVisitor):
def visitModule(self, mod):
for dfn in mod.dfns:
self.visit(dfn)
def visitType(self, type, depth=0):
self.visit(type.value, type.name, depth)
def visitSum(self, sum, name, depth):
info = self.type_info[name]
self.emit_type_alias(info)
if info.is_simple:
return
sum_match_arms = ""
for ty in sum.types:
variant_info = self.type_info[ty.name]
sum_match_arms += f" Self::{variant_info.rust_name}(node) => node.range(),"
self.emit_type_alias(variant_info)
self.emit_located_impl(variant_info)
if not info.no_cfg(self.type_info):
self.emit('#[cfg(feature = "all-nodes-with-ranges")]', 0)
self.emit(f"""
impl Located for {info.rust_sum_name} {{
fn range(&self) -> SourceRange {{
match self {{
{sum_match_arms}
}}
}}
}}
""".lstrip(), 0)
def visitProduct(self, product, name, depth):
info = self.type_info[name]
self.emit_type_alias(info)
self.emit_located_impl(info)
def emit_type_alias(self, info):
generics = "" if info.is_simple else "::<SourceRange>"
self.emit(f"pub type {info.rust_sum_name} = crate::generic::{info.rust_sum_name}{generics};", 0)
self.emit("", 0)
def emit_located_impl(self, info):
if not info.no_cfg(self.type_info):
self.emit('#[cfg(feature = "all-nodes-with-ranges")]', 0)
self.emit(
f"""
impl Located for {info.rust_sum_name} {{
fn range(&self) -> SourceRange {{
self.range
}}
}}
"""
, 0)
class ChainOfVisitors:
def __init__(self, *visitors):
self.visitors = visitors
@ -956,6 +1051,7 @@ class ChainOfVisitors:
def write_ast_def(mod, type_info, f):
f.write("use crate::text_size::TextRange;")
StructVisitor(f, type_info).visit(mod)
@ -967,32 +1063,12 @@ def write_visitor_def(mod, type_info, f):
VisitorModuleVisitor(f, type_info).visit(mod)
def write_located_def(mod, type_info, f):
f.write(
textwrap.dedent(
"""
use rustpython_parser_core::source_code::SourceRange;
def write_ranged_def(mod, type_info, f):
RangedDefVisitor(f, type_info).visit(mod)
pub type Located<T> = super::generic::Attributed<T, SourceRange>;
"""
)
)
for info in type_info.values():
if info.empty_field:
continue
if info.has_user_data:
generics = "::<SourceRange>"
else:
generics = ""
f.write(
f"pub type {info.rust_sum_name} = super::generic::{info.rust_sum_name}{generics};\n"
)
if info.rust_suffix:
if info.rust_suffix == "Data" and not info.has_expr:
generics = ""
f.write(
f"pub type {info.rust_sum_name}{info.rust_suffix} = super::generic::{info.rust_sum_name}{info.rust_suffix}{generics};\n"
)
def write_located_def(mod, type_info, f):
LocatedDefVisitor(f, type_info).visit(mod)
def write_ast_mod(mod, type_info, f):
@ -1016,10 +1092,10 @@ def write_ast_mod(mod, type_info, f):
def main(
input_filename,
ast_dir,
module_filename,
dump_module=False,
input_filename,
ast_dir,
module_filename,
dump_module=False,
):
auto_gen_msg = AUTO_GEN_MESSAGE.format("/".join(Path(__file__).parts[-2:]))
mod = asdl.parse(input_filename)
@ -1035,6 +1111,7 @@ def main(
for filename, write in [
("generic", write_ast_def),
("fold", write_fold_def),
("ranged", write_ranged_def),
("located", write_located_def),
("visitor", write_visitor_def),
]: