Generate all located variants, generate enum implementations

This commit is contained in:
Micha Reiser 2023-05-13 17:43:07 +02:00
parent 7e48ec46b9
commit bbfaf17b0b
No known key found for this signature in database
6 changed files with 480 additions and 134 deletions

View file

@ -1,6 +1,6 @@
# spell-checker:words dfn dfns
#! /usr/bin/env python
# ! /usr/bin/env python
"""Generate Rust code from an ASDL description."""
import sys
@ -362,17 +362,17 @@ class StructVisitor(EmitVisitor):
typ = f"{typ}<R>"
# don't box if we're doing Vec<T>, but do box if we're doing Vec<Option<Box<T>>>
if (
field_type
and field_type.boxed
and (not (parent.product or field.seq) or field.opt)
field_type
and field_type.boxed
and (not (parent.product or field.seq) or field.opt)
):
typ = f"Box<{typ}>"
if field.opt or (
# When a dictionary literal contains dictionary unpacking (e.g., `{**d}`),
# the expression to be unpacked goes in `values` with a `None` at the corresponding
# position in `keys`. To handle this, the type of `keys` needs to be `Option<Vec<T>>`.
constructor == "Dict"
and field.name == "keys"
# When a dictionary literal contains dictionary unpacking (e.g., `{**d}`),
# the expression to be unpacked goes in `values` with a `None` at the corresponding
# position in `keys`. To handle this, the type of `keys` needs to be `Option<Vec<T>>`.
constructor == "Dict"
and field.name == "keys"
):
typ = f"Option<{typ}>"
if field.seq:
@ -939,16 +939,32 @@ class RangedDefVisitor(EmitVisitor):
if info.is_simple:
return
sum_match_arms = ""
for ty in sum.types:
info = self.type_info[ty.name]
self.make_ranged_impl(info)
variant_info = self.type_info[ty.name]
sum_match_arms += f" Self::{variant_info.rust_name}(node) => node.range(),"
self.emit_ranged_impl(variant_info)
if not info.no_cfg(self.type_info):
self.emit('#[cfg(feature = "more-attributes")]', 0)
self.emit(f"""
impl Ranged for crate::{info.rust_sum_name} {{
fn range(&self) -> TextRange {{
match self {{
{sum_match_arms}
}}
}}
}}
""".lstrip(), 0)
def visitProduct(self, product, name, depth):
info = self.type_info[name]
self.make_ranged_impl(info)
self.emit_ranged_impl(info)
def make_ranged_impl(self, info):
def emit_ranged_impl(self, info):
if not info.no_cfg(self.type_info):
self.emit('#[cfg(feature = "more-attributes")]', 0)
@ -962,6 +978,7 @@ class RangedDefVisitor(EmitVisitor):
""".strip()
)
class LocatedDefVisitor(EmitVisitor):
def visitModule(self, mod):
for dfn in mod.dfns:
@ -973,36 +990,58 @@ class LocatedDefVisitor(EmitVisitor):
def visitSum(self, sum, name, depth):
info = self.type_info[name]
self.emit_type_alias(info)
if info.is_simple:
return
sum_match_arms = ""
for ty in sum.types:
info = self.type_info[ty.name]
self.make_located_impl(info)
variant_info = self.type_info[ty.name]
sum_match_arms += f" Self::{variant_info.rust_name}(node) => node.range(),"
self.emit_type_alias(variant_info)
self.emit_located_impl(variant_info)
if not info.no_cfg(self.type_info):
self.emit('#[cfg(feature = "more-attributes")]', 0)
self.emit(f"""
impl Located for {info.rust_sum_name} {{
fn range(&self) -> SourceRange {{
match self {{
{sum_match_arms}
}}
}}
}}
""".lstrip(), 0)
def visitProduct(self, product, name, depth):
info = self.type_info[name]
self.make_located_impl(info)
self.emit_type_alias(info)
self.emit_located_impl(info)
def make_located_impl(self, info):
def emit_type_alias(self, info):
generics = "" if info.is_simple else "::<SourceRange>"
self.emit(f"pub type {info.rust_sum_name} = crate::generic::{info.rust_sum_name}{generics};", 0)
self.emit("", 0)
def emit_located_impl(self, info):
if not info.no_cfg(self.type_info):
self.emit('#[cfg(feature = "more-attributes")]', 0)
self.emit(f"pub type {info.rust_sum_name} = crate::generic::{info.rust_sum_name}::<SourceRange>;", 0)
if not info.no_cfg(self.type_info):
self.emit('#[cfg(feature = "more-attributes")]', 0)
self.file.write(
self.emit(
f"""
impl Located for {info.rust_sum_name} {{
fn range(&self) -> SourceRange {{
self.range
}}
}}
""".strip()
)
"""
, 0)
class ChainOfVisitors:
def __init__(self, *visitors):
@ -1035,6 +1074,7 @@ def write_ranged_def(mod, type_info, f):
def write_located_def(mod, type_info, f):
LocatedDefVisitor(f, type_info).visit(mod)
def write_ast_mod(mod, type_info, f):
f.write(
textwrap.dedent(
@ -1056,10 +1096,10 @@ def write_ast_mod(mod, type_info, f):
def main(
input_filename,
ast_dir,
module_filename,
dump_module=False,
input_filename,
ast_dir,
module_filename,
dump_module=False,
):
auto_gen_msg = AUTO_GEN_MESSAGE.format("/".join(Path(__file__).parts[-2:]))
mod = asdl.parse(input_filename)