Add all node classes to ast module (#40)

This commit is contained in:
Jeong, YunWon 2023-05-16 16:55:26 +09:00 committed by GitHub
parent 0c7d16b61a
commit 53de75efc3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 1168 additions and 199 deletions

View file

@ -10,7 +10,7 @@ import re
from argparse import ArgumentParser
from pathlib import Path
from typing import Optional, Dict
from typing import Optional, Dict, Any
import asdl
@ -30,7 +30,18 @@ BUILTIN_INT_NAMES = {
"is_async": "bool",
}
RUST_KEYWORDS = {"if", "while", "for", "return", "match", "try", "await", "yield"}
RUST_KEYWORDS = {
"if",
"while",
"for",
"return",
"match",
"try",
"await",
"yield",
"in",
"mod",
}
def rust_field_name(name):
@ -80,34 +91,52 @@ def asdl_of(name, obj):
class TypeInfo:
name: str
type: asdl.Type
enum_name: Optional[str]
has_user_data: Optional[bool]
has_attributes: bool
is_simple: bool
empty_field: bool
children: set
fields: Optional[Any]
boxed: bool
product: bool
has_expr: bool = False
def __init__(self, name):
self.name = name
def __init__(self, type):
self.type = type
self.enum_name = None
self.has_user_data = None
self.has_attributes = False
self.is_simple = False
self.empty_field = False
self.children = set()
self.fields = None
self.boxed = False
self.product = False
self.product_has_expr = False
def __repr__(self):
return f"<TypeInfo: {self.name}>"
@property
def name(self):
return self.type.name
@property
def is_type(self):
return isinstance(self.type, asdl.Type)
@property
def is_product(self):
return self.is_type and isinstance(self.type.value, asdl.Product)
@property
def is_sum(self):
return self.is_type and isinstance(self.type.value, asdl.Sum)
@property
def has_expr(self):
return self.is_product and any(
f.type != "identifier" for f in self.type.value.fields
)
def no_cfg(self, typeinfo):
if self.product:
if self.is_product:
return self.has_attributes
elif self.enum_name:
return typeinfo[self.enum_name].has_attributes
@ -199,42 +228,46 @@ class FindUserDataTypesVisitor(asdl.VisitorBase):
info.determine_user_data(self.type_info, stack)
def visitType(self, type):
self.type_info[type.name] = TypeInfo(type.name)
self.visit(type.value, type.name)
self.type_info[type.name] = TypeInfo(type)
self.visit(type.value, type)
def visitSum(self, sum, name):
info = self.type_info[name]
if is_simple(sum):
def visitSum(self, sum, type):
info = self.type_info[type.name]
info.is_simple = is_simple(sum)
for cons in sum.types:
self.visit(cons, type, info.is_simple)
if info.is_simple:
info.has_user_data = False
info.is_simple = True
else:
for t in sum.types:
t_info = TypeInfo(t.name)
t_info.enum_name = name
t_info.empty_field = not t.fields
self.type_info[t.name] = t_info
self.add_children(t.name, t.fields)
if len(sum.types) > 1:
info.boxed = True
if sum.attributes:
# attributes means located, which has the `range: R` field
info.has_user_data = True
info.has_attributes = True
return
for t in sum.types:
self.add_children(t.name, t.fields)
if len(sum.types) > 1:
info.boxed = True
if sum.attributes:
# attributes means located, which has the `range: R` field
info.has_user_data = True
info.has_attributes = True
for variant in sum.types:
self.add_children(name, variant.fields)
self.add_children(type.name, variant.fields)
def visitProduct(self, product, name):
info = self.type_info[name]
def visitConstructor(self, cons, type, simple):
info = self.type_info[cons.name] = TypeInfo(cons)
info.enum_name = type.name
info.is_simple = simple
def visitProduct(self, product, type):
info = self.type_info[type.name]
if product.attributes:
# attributes means located, which has the `range: R` field
info.has_user_data = True
info.has_attributes = True
info.has_expr = product_has_expr(product)
if len(product.fields) > 2:
info.boxed = True
info.product = True
self.add_children(name, product.fields)
self.add_children(type.name, product.fields)
def add_children(self, name, fields):
self.type_info[name].children.update(
@ -249,10 +282,6 @@ def rust_field(field_name):
return field_name
def product_has_expr(product):
return any(f.type != "identifier" for f in product.fields)
class StructVisitor(EmitVisitor):
"""Visitor to generate type-defs for AST."""
@ -264,19 +293,19 @@ class StructVisitor(EmitVisitor):
self.visit(dfn)
def visitType(self, type, depth=0):
self.visit(type.value, type.name, depth)
self.visit(type.value, type, depth)
def visitSum(self, sum, name, depth):
def visitSum(self, sum, type, depth):
if is_simple(sum):
self.simple_sum(sum, name, depth)
self.simple_sum(sum, type, depth)
else:
self.sum_with_constructors(sum, name, depth)
self.sum_with_constructors(sum, type, depth)
(generics_applied,) = self.apply_generics(name, "R")
(generics_applied,) = self.apply_generics(type.name, "R")
self.emit(
f"""
impl{generics_applied} Node for {rust_type_name(name)}{generics_applied} {{
const NAME: &'static str = "{name}";
impl{generics_applied} Node for {rust_type_name(type.name)}{generics_applied} {{
const NAME: &'static str = "{type.name}";
const FIELD_NAMES: &'static [&'static str] = &[];
}}
""",
@ -292,19 +321,64 @@ class StructVisitor(EmitVisitor):
else:
self.emit("pub range: OptionalRange<R>,", depth + 1)
def simple_sum(self, sum, name, depth):
rust_name = rust_type_name(name)
def simple_sum(self, sum, type, depth):
rust_name = rust_type_name(type.name)
self.emit_attrs(depth)
self.emit("#[derive(is_macro::Is, Copy, Hash, Eq)]", depth)
self.emit(f"pub enum {rust_name} {{", depth)
for variant in sum.types:
self.emit(f"{variant.name},", depth + 1)
for cons in sum.types:
self.emit(f"{cons.name},", depth + 1)
self.emit("}", depth)
self.emit(f"impl {rust_name} {{", depth)
needs_escape = any(rust_field_name(t.name) in RUST_KEYWORDS for t in sum.types)
if needs_escape:
prefix = rust_field_name(type.name) + "_"
else:
prefix = ""
for cons in sum.types:
self.emit(
textwrap.dedent(
f"""
#[inline]
pub const fn {prefix}{rust_field_name(cons.name)}(&self) -> Option<{rust_name}{cons.name}> {{
match self {{
{rust_name}::{cons.name} => Some({rust_name}{cons.name}),
_ => None,
}}
}}
"""
),
depth,
)
self.emit("}", depth)
self.emit("", depth)
def sum_with_constructors(self, sum, name, depth):
type_info = self.type_info[name]
rust_name = rust_type_name(name)
for cons in sum.types:
self.emit(
f"""
pub struct {rust_name}{cons.name};
impl From<{rust_name}{cons.name}> for {rust_name} {{
fn from(_: {rust_name}{cons.name}) -> Self {{
{rust_name}::{cons.name}
}}
}}
impl Node for {rust_name}{cons.name} {{
const NAME: &'static str = "{cons.name}";
const FIELD_NAMES: &'static [&'static str] = &[];
}}
impl std::cmp::PartialEq<{rust_name}> for {rust_name}{cons.name} {{
#[inline]
fn eq(&self, other: &{rust_name}) -> bool {{
matches!(other, {rust_name}::{cons.name})
}}
}}
""",
0,
)
def sum_with_constructors(self, sum, type, depth):
type_info = self.type_info[type.name]
rust_name = rust_type_name(type.name)
# all the attributes right now are for location, so if it has attrs we
# can just wrap it in Attributed<>
@ -376,7 +450,7 @@ class StructVisitor(EmitVisitor):
if (
field_type
and field_type.boxed
and (not (parent.product or field.seq) or field.opt)
and (not (parent.is_product or field.seq) or field.opt)
):
typ = f"Box<{typ}>"
if field.opt or (
@ -394,9 +468,9 @@ class StructVisitor(EmitVisitor):
name = rust_field(field.name)
self.emit(f"{vis}{name}: {typ},", depth)
def visitProduct(self, product, name, depth):
type_info = self.type_info[name]
product_name = rust_type_name(name)
def visitProduct(self, product, type, depth):
type_info = self.type_info[type.name]
product_name = rust_type_name(type.name)
self.emit_attrs(depth)
self.emit(f"pub struct {product_name}<R = TextRange> {{", depth)
@ -410,7 +484,7 @@ class StructVisitor(EmitVisitor):
self.emit(
f"""
impl<R> Node for {product_name}<R> {{
const NAME: &'static str = "{name}";
const NAME: &'static str = "{type.name}";
const FIELD_NAMES: &'static [&'static str] = &[
{', '.join(field_names)}
];
@ -711,6 +785,158 @@ class VisitorModuleVisitor(EmitVisitor):
VisitorTraitDefVisitor(self.file, self.type_info).visit(mod, depth)
class RangedDefVisitor(EmitVisitor):
def visitModule(self, mod):
for dfn in mod.dfns:
self.visit(dfn)
def visitType(self, type, depth=0):
self.visit(type.value, type.name, depth)
def visitSum(self, sum, name, depth):
info = self.type_info[name]
self.emit_type_alias(info)
if info.is_simple:
for ty in sum.types:
variant_info = self.type_info[ty.name]
self.emit_type_alias(variant_info)
return
sum_match_arms = ""
for ty in sum.types:
variant_info = self.type_info[ty.name]
sum_match_arms += (
f" Self::{variant_info.rust_name}(node) => node.range(),"
)
self.emit_type_alias(variant_info)
self.emit_ranged_impl(variant_info)
if not info.no_cfg(self.type_info):
self.emit('#[cfg(feature = "all-nodes-with-ranges")]', 0)
self.emit(
f"""
impl Ranged for crate::{info.rust_sum_name} {{
fn range(&self) -> TextRange {{
match self {{
{sum_match_arms}
}}
}}
}}
""".lstrip(),
0,
)
def visitProduct(self, product, name, depth):
info = self.type_info[name]
self.emit_type_alias(info)
self.emit_ranged_impl(info)
def emit_type_alias(self, info):
return # disable
generics = "" if info.is_simple else "::<TextRange>"
self.emit(
f"pub type {info.rust_sum_name} = crate::generic::{info.rust_sum_name}{generics};",
0,
)
self.emit("", 0)
def emit_ranged_impl(self, info):
if not info.no_cfg(self.type_info):
self.emit('#[cfg(feature = "all-nodes-with-ranges")]', 0)
self.file.write(
f"""
impl Ranged for crate::generic::{info.rust_sum_name}::<TextRange> {{
fn range(&self) -> TextRange {{
self.range
}}
}}
""".strip()
)
class LocatedDefVisitor(EmitVisitor):
def visitModule(self, mod):
for dfn in mod.dfns:
self.visit(dfn)
def visitType(self, type, depth=0):
self.visit(type.value, type.name, depth)
def visitSum(self, sum, name, depth):
info = self.type_info[name]
self.emit_type_alias(info)
if info.is_simple:
for ty in sum.types:
variant_info = self.type_info[ty.name]
self.emit_type_alias(variant_info)
return
sum_match_arms = ""
for ty in sum.types:
variant_info = self.type_info[ty.name]
sum_match_arms += (
f" Self::{variant_info.rust_name}(node) => node.range(),"
)
self.emit_type_alias(variant_info)
self.emit_located_impl(variant_info)
if not info.no_cfg(self.type_info):
self.emit('#[cfg(feature = "all-nodes-with-ranges")]', 0)
self.emit(
f"""
impl Located for {info.rust_sum_name} {{
fn range(&self) -> SourceRange {{
match self {{
{sum_match_arms}
}}
}}
}}
""".lstrip(),
0,
)
def visitProduct(self, product, name, depth):
info = self.type_info[name]
self.emit_type_alias(info)
self.emit_located_impl(info)
def emit_type_alias(self, info):
generics = "" if info.is_simple else "::<SourceRange>"
self.emit(
f"pub type {info.rust_sum_name} = crate::generic::{info.rust_sum_name}{generics};",
0,
)
self.emit("", 0)
def emit_located_impl(self, info):
if not info.no_cfg(self.type_info):
self.emit('#[cfg(feature = "all-nodes-with-ranges")]', 0)
self.emit(
f"""
impl Located for {info.rust_sum_name} {{
fn range(&self) -> SourceRange {{
self.range
}}
}}
""",
0,
)
class StdlibClassDefVisitor(EmitVisitor):
def visitModule(self, mod):
for dfn in mod.dfns:
@ -836,7 +1062,10 @@ class StdlibTraitImplVisitor(EmitVisitor):
depth,
)
self.emit("};", depth + 3)
self.emit("NodeAst.into_ref_with_type(vm, node_type.to_owned()).unwrap().into()", depth + 2)
self.emit(
"NodeAst.into_ref_with_type(vm, node_type.to_owned()).unwrap().into()",
depth + 2,
)
else:
self.emit("match self {", depth + 2)
for cons in sum.types:
@ -1011,138 +1240,6 @@ class StdlibTraitImplVisitor(EmitVisitor):
return f"Node::ast_from_object(_vm, get_node_field(_vm, &_object, {name}, {json.dumps(typename)})?)?"
class RangedDefVisitor(EmitVisitor):
def visitModule(self, mod):
for dfn in mod.dfns:
self.visit(dfn)
def visitType(self, type, depth=0):
self.visit(type.value, type.name, depth)
def visitSum(self, sum, name, depth):
info = self.type_info[name]
if info.is_simple:
return
sum_match_arms = ""
for ty in sum.types:
variant_info = self.type_info[ty.name]
sum_match_arms += (
f" Self::{variant_info.rust_name}(node) => node.range(),"
)
self.emit_ranged_impl(variant_info)
if not info.no_cfg(self.type_info):
self.emit('#[cfg(feature = "all-nodes-with-ranges")]', 0)
self.emit(
f"""
impl Ranged for crate::{info.rust_sum_name} {{
fn range(&self) -> TextRange {{
match self {{
{sum_match_arms}
}}
}}
}}
""".lstrip(),
0,
)
def visitProduct(self, product, name, depth):
info = self.type_info[name]
self.emit_ranged_impl(info)
def emit_ranged_impl(self, info):
if not info.no_cfg(self.type_info):
self.emit('#[cfg(feature = "all-nodes-with-ranges")]', 0)
self.file.write(
f"""
impl Ranged for crate::generic::{info.rust_sum_name}::<TextRange> {{
fn range(&self) -> TextRange {{
self.range
}}
}}
""".strip()
)
class LocatedDefVisitor(EmitVisitor):
def visitModule(self, mod):
for dfn in mod.dfns:
self.visit(dfn)
def visitType(self, type, depth=0):
self.visit(type.value, type.name, depth)
def visitSum(self, sum, name, depth):
info = self.type_info[name]
self.emit_type_alias(info)
if info.is_simple:
return
sum_match_arms = ""
for ty in sum.types:
variant_info = self.type_info[ty.name]
sum_match_arms += (
f" Self::{variant_info.rust_name}(node) => node.range(),"
)
self.emit_type_alias(variant_info)
self.emit_located_impl(variant_info)
if not info.no_cfg(self.type_info):
self.emit('#[cfg(feature = "all-nodes-with-ranges")]', 0)
self.emit(
f"""
impl Located for {info.rust_sum_name} {{
fn range(&self) -> SourceRange {{
match self {{
{sum_match_arms}
}}
}}
}}
""".lstrip(),
0,
)
def visitProduct(self, product, name, depth):
info = self.type_info[name]
self.emit_type_alias(info)
self.emit_located_impl(info)
def emit_type_alias(self, info):
generics = "" if info.is_simple else "::<SourceRange>"
self.emit(
f"pub type {info.rust_sum_name} = crate::generic::{info.rust_sum_name}{generics};",
0,
)
self.emit("", 0)
def emit_located_impl(self, info):
if not info.no_cfg(self.type_info):
self.emit('#[cfg(feature = "all-nodes-with-ranges")]', 0)
self.emit(
f"""
impl Located for {info.rust_sum_name} {{
fn range(&self) -> SourceRange {{
self.range
}}
}}
""",
0,
)
class ChainOfVisitors:
def __init__(self, *visitors):
self.visitors = visitors
@ -1211,16 +1308,18 @@ def main(
type_info = {}
FindUserDataTypesVisitor(type_info).visit(mod)
from functools import partial as p
for filename, write in [
("generic", write_ast_def),
("fold", write_fold_def),
("ranged", write_ranged_def),
("located", write_located_def),
("visitor", write_visitor_def),
("generic", p(write_ast_def, mod, type_info)),
("fold", p(write_fold_def, mod, type_info)),
("ranged", p(write_ranged_def, mod, type_info)),
("located", p(write_located_def, mod, type_info)),
("visitor", p(write_visitor_def, mod, type_info)),
]:
with (ast_dir / f"{filename}.rs").open("w") as f:
f.write(auto_gen_msg)
write(mod, type_info, f)
write(f)
with module_filename.open("w") as module_file:
module_file.write(auto_gen_msg)