Move range to node (#23)

* black + clippy

* Fix module generation
This commit is contained in:
Jeong, YunWon 2023-05-15 16:20:22 +09:00 committed by GitHub
parent 192379cede
commit 718354673e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 170 additions and 101 deletions

View file

@ -258,7 +258,6 @@ class StructVisitor(EmitVisitor):
def __init__(self, *args, **kw):
super().__init__(*args, **kw)
self.rust_type_defs = []
def visitModule(self, mod):
for dfn in mod.dfns:
@ -359,17 +358,17 @@ class StructVisitor(EmitVisitor):
typ = f"{typ}<R>"
# don't box if we're doing Vec<T>, but do box if we're doing Vec<Option<Box<T>>>
if (
field_type
and field_type.boxed
and (not (parent.product or field.seq) or field.opt)
field_type
and field_type.boxed
and (not (parent.product or field.seq) or field.opt)
):
typ = f"Box<{typ}>"
if field.opt or (
# When a dictionary literal contains dictionary unpacking (e.g., `{**d}`),
# the expression to be unpacked goes in `values` with a `None` at the corresponding
# position in `keys`. To handle this, the type of `keys` needs to be `Option<Vec<T>>`.
constructor == "Dict"
and field.name == "keys"
# When a dictionary literal contains dictionary unpacking (e.g., `{**d}`),
# the expression to be unpacked goes in `values` with a `None` at the corresponding
# position in `keys`. To handle this, the type of `keys` needs to be `Option<Vec<T>>`.
constructor == "Dict"
and field.name == "keys"
):
typ = f"Option<{typ}>"
if field.seq:
@ -579,9 +578,10 @@ class VisitorTraitDefVisitor(StructVisitor):
def emit_visitor(self, nodename, depth, has_node=True):
type_info = self.type_info[nodename]
node_type = type_info.rust_sum_name
generic, = self.apply_generics(nodename, "R")
(generic,) = self.apply_generics(nodename, "R")
self.emit(
f"fn visit_{type_info.sum_name}(&mut self, node: {node_type}{generic}) {{", depth
f"fn visit_{type_info.sum_name}(&mut self, node: {node_type}{generic}) {{",
depth,
)
if has_node:
self.emit(f"self.generic_visit_{type_info.sum_name}(node)", depth + 1)
@ -594,7 +594,7 @@ class VisitorTraitDefVisitor(StructVisitor):
node_type = type_info.rust_sum_name
else:
node_type = "()"
generic, = self.apply_generics(nodename, "R")
(generic,) = self.apply_generics(nodename, "R")
self.emit(
f"fn generic_visit_{type_info.sum_name}(&mut self, node: {node_type}{generic}) {{",
depth,
@ -677,7 +677,7 @@ class VisitorModuleVisitor(EmitVisitor):
VisitorTraitDefVisitor(self.file, self.type_info).visit(mod, depth)
class class_defVisitor(EmitVisitor):
class StdlibClassDefVisitor(EmitVisitor):
def visitModule(self, mod):
for dfn in mod.dfns:
self.visit(dfn)
@ -686,9 +686,9 @@ class class_defVisitor(EmitVisitor):
self.visit(type.value, type.name, depth)
def visitSum(self, sum, name, depth):
struct_name = "NodeKind" + rust_type_name(name)
struct_name = "Node" + rust_type_name(name)
self.emit(
f'#[pyclass(module = "_ast", name = {json.dumps(name)}, base = "AstNode")]',
f'#[pyclass(module = "_ast", name = {json.dumps(name)}, base = "NodeAst")]',
depth,
)
self.emit(f"struct {struct_name};", depth)
@ -703,8 +703,12 @@ class class_defVisitor(EmitVisitor):
def visitProduct(self, product, name, depth):
self.gen_class_def(name, product.fields, product.attributes, depth)
def gen_class_def(self, name, fields, attrs, depth, base="AstNode"):
struct_name = "Node" + rust_type_name(name)
def gen_class_def(self, name, fields, attrs, depth, base=None):
if base is None:
base = "NodeAst"
struct_name = "Node" + rust_type_name(name)
else:
struct_name = base + rust_type_name(name)
self.emit(
f'#[pyclass(module = "_ast", name = {json.dumps(name)}, base = {json.dumps(base)})]',
depth,
@ -735,7 +739,7 @@ class class_defVisitor(EmitVisitor):
self.emit("}", depth)
class ExtendModuleVisitor(EmitVisitor):
class StdlibExtendModuleVisitor(EmitVisitor):
def visitModule(self, mod):
depth = 0
self.emit(
@ -753,24 +757,24 @@ class ExtendModuleVisitor(EmitVisitor):
def visitSum(self, sum, name, depth):
rust_name = rust_type_name(name)
self.emit(
f"{json.dumps(name)} => NodeKind{rust_name}::make_class(&vm.ctx),", depth
)
self.emit(f"{json.dumps(name)} => Node{rust_name}::make_class(&vm.ctx),", depth)
for cons in sum.types:
self.visit(cons, depth)
self.visit(cons, depth, rust_name)
def visitConstructor(self, cons, depth):
self.gen_extension(cons.name, depth)
def visitConstructor(self, cons, depth, rust_name):
self.gen_extension(cons.name, depth, rust_name)
def visitProduct(self, product, name, depth):
self.gen_extension(name, depth)
def gen_extension(self, name, depth):
def gen_extension(self, name, depth, base=""):
rust_name = rust_type_name(name)
self.emit(f"{json.dumps(name)} => Node{rust_name}::make_class(&vm.ctx),", depth)
self.emit(
f"{json.dumps(name)} => Node{base}{rust_name}::make_class(&vm.ctx),", depth
)
class TraitImplVisitor(EmitVisitor):
class StdlibTraitImplVisitor(EmitVisitor):
def visitModule(self, mod):
for dfn in mod.dfns:
self.visit(dfn)
@ -779,45 +783,87 @@ class TraitImplVisitor(EmitVisitor):
self.visit(type.value, type.name, depth)
def visitSum(self, sum, name, depth):
rust_name = enum_name = rust_type_name(name)
if sum.attributes:
rust_name = enum_name + "Kind"
rust_name = rust_type_name(name)
self.emit(f"impl NamedNode for ast::located::{rust_name} {{", depth)
self.emit(f"const NAME: &'static str = {json.dumps(name)};", depth + 1)
self.emit("}", depth)
self.emit("// sum", depth)
self.emit(f"impl Node for ast::located::{rust_name} {{", depth)
self.emit(
"fn ast_to_object(self, _vm: &VirtualMachine) -> PyObjectRef {", depth + 1
"fn ast_to_object(self, vm: &VirtualMachine) -> PyObjectRef {", depth + 1
)
self.emit("match self {", depth + 2)
for variant in sum.types:
self.constructor_to_object(variant, enum_name, rust_name, depth + 3)
self.emit("}", depth + 2)
simple = is_simple(sum)
if simple:
self.emit("let node_type = match self {", depth + 2)
for cons in sum.types:
self.emit(
f"ast::located::{rust_name}::{cons.name} => Node{rust_name}{cons.name}::static_type(),",
depth,
)
self.emit("};", depth + 3)
self.emit("NodeAst.into_ref_with_type(vm, node_type.to_owned()).unwrap().into()", depth + 2)
else:
self.emit("match self {", depth + 2)
for cons in sum.types:
self.emit(
f"ast::located::{rust_name}::{cons.name}(cons) => cons.ast_to_object(vm),",
depth + 3,
)
self.emit("}", depth + 2)
self.emit("}", depth + 1)
self.emit(
"fn ast_from_object(_vm: &VirtualMachine, _object: PyObjectRef) -> PyResult<Self> {",
depth + 1,
)
self.gen_sum_from_object(sum, name, enum_name, rust_name, depth + 2)
self.gen_sum_from_object(sum, name, rust_name, depth + 2)
self.emit("}", depth + 1)
self.emit("}", depth)
def constructor_to_object(self, cons, enum_name, rust_name, depth):
self.emit(f"ast::located::{rust_name}::{cons.name}", depth)
if cons.fields:
fields_pattern = self.make_pattern(cons.fields)
self.emit(
f"( ast::located::{enum_name}{cons.name} {{ {fields_pattern} }} )",
depth,
)
self.emit(" => {", depth)
self.make_node(cons.name, cons.fields, depth + 1)
if not is_simple(sum):
for cons in sum.types:
self.visit(cons, sum, rust_name, depth)
def visitConstructor(self, cons, sum, sum_rust_name, depth):
rust_name = rust_type_name(cons.name)
self.emit("// constructor", depth)
self.emit(
f"impl NamedNode for ast::located::{sum_rust_name}{rust_name} {{", depth
)
self.emit(f"const NAME: &'static str = {json.dumps(cons.name)};", depth + 1)
self.emit("}", depth)
self.emit(f"impl Node for ast::located::{sum_rust_name}{rust_name} {{", depth)
fields_pattern = self.make_pattern(cons.fields)
self.emit(
"fn ast_to_object(self, _vm: &VirtualMachine) -> PyObjectRef {", depth + 1
)
self.emit(
f"let ast::located::{sum_rust_name}{rust_name} {{ {fields_pattern} }} = self;",
depth,
)
self.make_node(cons.name, sum, cons.fields, depth + 2, sum_rust_name)
self.emit("}", depth + 1)
self.emit(
"fn ast_from_object(_vm: &VirtualMachine, _object: PyObjectRef) -> PyResult<Self> {",
depth + 1,
)
self.gen_product_from_object(
cons, cons.name, f"{sum_rust_name}{rust_name}", sum.attributes, depth + 2
)
self.emit("}", depth + 1)
self.emit("}", depth + 1)
def visitProduct(self, product, name, depth):
struct_name = rust_type_name(name)
self.emit("// product", depth)
self.emit(f"impl NamedNode for ast::located::{struct_name} {{", depth)
self.emit(f"const NAME: &'static str = {json.dumps(name)};", depth + 1)
self.emit("}", depth)
@ -827,53 +873,57 @@ class TraitImplVisitor(EmitVisitor):
)
fields_pattern = self.make_pattern(product.fields)
self.emit(
f"let ast::located::{struct_name} {{ {fields_pattern} }} = self;", depth + 2
f"let ast::located::{struct_name} {{ {fields_pattern} }} = self;",
depth + 2,
)
self.make_node(name, product.fields, depth + 2)
self.make_node(name, product, product.fields, depth + 2)
self.emit("}", depth + 1)
self.emit(
"fn ast_from_object(_vm: &VirtualMachine, _object: PyObjectRef) -> PyResult<Self> {",
depth + 1,
)
self.gen_product_from_object(product, name, struct_name, depth + 2)
self.gen_product_from_object(
product, name, struct_name, product.attributes, depth + 2
)
self.emit("}", depth + 1)
self.emit("}", depth)
def make_node(self, variant, fields, depth):
def make_node(self, variant, owner, fields, depth, base=""):
rust_variant = rust_type_name(variant)
self.emit(
f"let _node = AstNode.into_ref_with_type(_vm, Node{rust_variant}::static_type().to_owned()).unwrap();",
f"let node = NodeAst.into_ref_with_type(_vm, Node{base}{rust_variant}::static_type().to_owned()).unwrap();",
depth,
)
if fields:
self.emit("let _dict = _node.as_object().dict().unwrap();", depth)
if fields or owner.attributes:
self.emit("let dict = node.as_object().dict().unwrap();", depth)
for f in fields:
self.emit(
f"_dict.set_item({json.dumps(f.name)}, {rust_field(f.name)}.ast_to_object(_vm), _vm).unwrap();",
f"dict.set_item({json.dumps(f.name)}, {rust_field(f.name)}.ast_to_object(_vm), _vm).unwrap();",
depth,
)
self.emit("_node.into()", depth)
if owner.attributes:
self.emit("node_add_location(&dict, _range, _vm);", depth)
self.emit("node.into()", depth)
def make_pattern(self, fields):
return ",".join(rust_field(f.name) for f in fields)
return "".join(f"{rust_field(f.name)}," for f in fields) + "range: _range"
def gen_sum_from_object(self, sum, sum_name, enum_name, rust_name, depth):
def gen_sum_from_object(self, sum, sum_name, rust_name, depth):
# if sum.attributes:
# self.extract_location(sum_name, depth)
self.emit("let _cls = _object.class();", depth)
self.emit("Ok(", depth)
for cons in sum.types:
self.emit(f"if _cls.is(Node{cons.name}::static_type()) {{", depth)
if cons.fields:
self.emit(
f"if _cls.is(Node{rust_name}{cons.name}::static_type()) {{", depth
)
self.emit(f"ast::located::{rust_name}::{cons.name}", depth + 1)
if not is_simple(sum):
self.emit(
f"ast::located::{rust_name}::{cons.name} (ast::located::{enum_name}{cons.name} {{",
f"(ast::located::{rust_name}{cons.name}::ast_from_object(_vm, _object)?)",
depth + 1,
)
self.gen_construction_fields(cons, sum_name, depth + 1)
self.emit("})", depth + 1)
else:
self.emit(f"ast::located::{rust_name}::{cons.name}", depth + 1)
self.emit("} else", depth)
self.emit("{", depth)
@ -881,12 +931,13 @@ class TraitImplVisitor(EmitVisitor):
self.emit(f"return Err(_vm.new_type_error({msg}));", depth + 1)
self.emit("})", depth)
def gen_product_from_object(self, product, product_name, struct_name, depth):
# if product.attributes:
# self.extract_location(product_name, depth)
def gen_product_from_object(
self, product, product_name, struct_name, has_attributes, depth
):
self.emit("Ok(", depth)
self.gen_construction(struct_name, product, product_name, depth + 1)
self.gen_construction(
struct_name, product, product_name, has_attributes, depth + 1
)
self.emit(")", depth)
def gen_construction_fields(self, cons, name, depth):
@ -896,9 +947,13 @@ class TraitImplVisitor(EmitVisitor):
depth + 1,
)
def gen_construction(self, cons_path, cons, name, depth):
def gen_construction(self, cons_path, cons, name, attributes, depth):
self.emit(f"ast::located::{cons_path} {{", depth)
self.gen_construction_fields(cons, name, depth + 1)
if attributes:
self.emit(f'range: range_from_object(_vm, _object, "{name}")?,', depth + 1)
else:
self.emit("range: Default::default(),", depth + 1)
self.emit("}", depth)
def extract_location(self, typename, depth):
@ -940,21 +995,26 @@ class RangedDefVisitor(EmitVisitor):
for ty in sum.types:
variant_info = self.type_info[ty.name]
sum_match_arms += f" Self::{variant_info.rust_name}(node) => node.range(),"
sum_match_arms += (
f" Self::{variant_info.rust_name}(node) => node.range(),"
)
self.emit_ranged_impl(variant_info)
if not info.no_cfg(self.type_info):
self.emit('#[cfg(feature = "all-nodes-with-ranges")]', 0)
self.emit(f"""
impl Ranged for crate::{info.rust_sum_name} {{
fn range(&self) -> TextRange {{
match self {{
{sum_match_arms}
self.emit(
f"""
impl Ranged for crate::{info.rust_sum_name} {{
fn range(&self) -> TextRange {{
match self {{
{sum_match_arms}
}}
}}
}}
}}
""".lstrip(), 0)
""".lstrip(),
0,
)
def visitProduct(self, product, name, depth):
info = self.type_info[name]
@ -996,22 +1056,27 @@ class LocatedDefVisitor(EmitVisitor):
for ty in sum.types:
variant_info = self.type_info[ty.name]
sum_match_arms += f" Self::{variant_info.rust_name}(node) => node.range(),"
sum_match_arms += (
f" Self::{variant_info.rust_name}(node) => node.range(),"
)
self.emit_type_alias(variant_info)
self.emit_located_impl(variant_info)
if not info.no_cfg(self.type_info):
self.emit('#[cfg(feature = "all-nodes-with-ranges")]', 0)
self.emit(f"""
impl Located for {info.rust_sum_name} {{
fn range(&self) -> SourceRange {{
match self {{
{sum_match_arms}
self.emit(
f"""
impl Located for {info.rust_sum_name} {{
fn range(&self) -> SourceRange {{
match self {{
{sum_match_arms}
}}
}}
}}
}}
""".lstrip(), 0)
""".lstrip(),
0,
)
def visitProduct(self, product, name, depth):
info = self.type_info[name]
@ -1022,7 +1087,10 @@ class LocatedDefVisitor(EmitVisitor):
def emit_type_alias(self, info):
generics = "" if info.is_simple else "::<SourceRange>"
self.emit(f"pub type {info.rust_sum_name} = crate::generic::{info.rust_sum_name}{generics};", 0)
self.emit(
f"pub type {info.rust_sum_name} = crate::generic::{info.rust_sum_name}{generics};",
0,
)
self.emit("", 0)
def emit_located_impl(self, info):
@ -1036,8 +1104,9 @@ class LocatedDefVisitor(EmitVisitor):
self.range
}}
}}
"""
, 0)
""",
0,
)
class ChainOfVisitors:
@ -1084,18 +1153,18 @@ def write_ast_mod(mod, type_info, f):
)
c = ChainOfVisitors(
class_defVisitor(f, type_info),
TraitImplVisitor(f, type_info),
ExtendModuleVisitor(f, type_info),
StdlibClassDefVisitor(f, type_info),
StdlibTraitImplVisitor(f, type_info),
StdlibExtendModuleVisitor(f, type_info),
)
c.visit(mod)
def main(
input_filename,
ast_dir,
module_filename,
dump_module=False,
input_filename,
ast_dir,
module_filename,
dump_module=False,
):
auto_gen_msg = AUTO_GEN_MESSAGE.format("/".join(Path(__file__).parts[-2:]))
mod = asdl.parse(input_filename)

View file

@ -1,4 +1,4 @@
use crate::{Constant, Excepthandler, Expr, Pattern, Stmt};
use crate::{Constant, Expr};
impl<R> Expr<R> {
/// Returns a short name for the node suitable for use in error messages.
@ -55,10 +55,10 @@ impl<R> Expr<R> {
}
#[cfg(target_arch = "x86_64")]
static_assertions::assert_eq_size!(Expr, [u8; 72]);
static_assertions::assert_eq_size!(crate::Expr, [u8; 72]);
#[cfg(target_arch = "x86_64")]
static_assertions::assert_eq_size!(Stmt, [u8; 136]);
static_assertions::assert_eq_size!(crate::Stmt, [u8; 136]);
#[cfg(target_arch = "x86_64")]
static_assertions::assert_eq_size!(Pattern, [u8; 96]);
static_assertions::assert_eq_size!(crate::Pattern, [u8; 96]);
#[cfg(target_arch = "x86_64")]
static_assertions::assert_eq_size!(Excepthandler, [u8; 64]);
static_assertions::assert_eq_size!(crate::Excepthandler, [u8; 64]);