No Attributed + custom for Range PoC

This commit is contained in:
Jeong YunWon 2023-05-12 03:41:41 +09:00 committed by Micha Reiser
parent ba0ae51e82
commit 904f5c8b37
No known key found for this signature in database
10 changed files with 1732 additions and 1184 deletions

View file

@ -84,6 +84,7 @@ class TypeInfo:
enum_name: Optional[str]
has_user_data: Optional[bool]
has_attributes: bool
is_simple: bool
empty_field: bool
children: set
boxed: bool
@ -95,6 +96,7 @@ class TypeInfo:
self.enum_name = None
self.has_user_data = None
self.has_attributes = False
self.is_simple = False
self.empty_field = False
self.children = set()
self.boxed = False
@ -104,6 +106,14 @@ class TypeInfo:
def __repr__(self):
return f"<TypeInfo: {self.name}>"
def needs_cfg(self, typeinfo):
if self.product:
return self.has_attributes
elif self.enum_name:
return typeinfo[self.enum_name].has_attributes
else:
return self.has_attributes
@property
def rust_name(self):
return rust_type_name(self.name)
@ -124,19 +134,6 @@ class TypeInfo:
name = rust_type_name(self.enum_name) + rust_name
return name
@property
def rust_suffix(self):
if self.product:
if self.has_attributes:
return "Data"
else:
return ""
else:
if self.has_attributes:
return "Kind"
else:
return ""
def determine_user_data(self, type_info, stack):
if self.name in stack:
return None
@ -208,6 +205,7 @@ class FindUserDataTypesVisitor(asdl.VisitorBase):
info = self.type_info[name]
if is_simple(sum):
info.has_user_data = False
info.is_simple = True
else:
for t in sum.types:
t_info = TypeInfo(t.name)
@ -277,6 +275,16 @@ class StructVisitor(EmitVisitor):
def emit_attrs(self, depth):
self.emit("#[derive(Clone, Debug, PartialEq)]", depth)
def emit_custom(self, has_attributes, depth):
if has_attributes:
self.emit("pub custom: U,", depth + 1)
else:
self.emit('#[cfg(feature = "more-attributes")]', depth + 1)
self.emit("pub custom: U,", depth + 1)
self.emit('#[cfg(not(feature = "more-attributes"))]', depth + 1)
self.emit("pub custom: std::marker::PhantomData<U>,", depth + 1)
def simple_sum(self, sum, name, depth):
rust_name = rust_type_name(name)
self.emit_attrs(depth)
@ -289,7 +297,6 @@ class StructVisitor(EmitVisitor):
def sum_with_constructors(self, sum, name, depth):
type_info = self.type_info[name]
suffix = type_info.rust_suffix
rust_name = rust_type_name(name)
# all the attributes right now are for location, so if it has attrs we
# can just wrap it in Attributed<>
@ -297,10 +304,9 @@ class StructVisitor(EmitVisitor):
for t in sum.types:
self.sum_subtype_struct(type_info, t, rust_name, depth)
generics, generics_applied = self.apply_generics(name, "U = ()", "U")
self.emit_attrs(depth)
self.emit("#[derive(is_macro::Is)]", depth)
self.emit(f"pub enum {rust_name}{suffix}{generics} {{", depth)
self.emit(f"pub enum {rust_name}<U> {{", depth)
needs_escape = any(rust_field_name(t.name) in RUST_KEYWORDS for t in sum.types)
for t in sum.types:
if needs_escape:
@ -308,33 +314,29 @@ class StructVisitor(EmitVisitor):
f'#[is(name = "{rust_field_name(t.name)}_{rust_name.lower()}")]',
depth + 1,
)
(t_generics_applied,) = self.apply_generics(t.name, "U")
self.emit(
f"{t.name}({rust_name}{t.name}{t_generics_applied}),", depth + 1
f"{t.name}({rust_name}{t.name}<U>),", depth + 1
)
self.emit("}", depth)
if type_info.has_attributes:
self.emit(
f"pub type {rust_name}<U = ()> = Attributed<{rust_name}{suffix}{generics_applied}, U>;",
depth,
)
self.emit("", depth)
def sum_subtype_struct(self, sum_type_info, t, rust_name, depth):
self.emit_attrs(depth)
generics, generics_applied = self.apply_generics(t.name, "U = ()", "U")
payload_name = f"{rust_name}{t.name}"
self.emit(f"pub struct {payload_name}{generics} {{", depth)
self.emit(f"pub struct {payload_name}<U> {{", depth)
for f in t.fields:
self.visit(f, sum_type_info, "pub ", depth + 1, t.name)
self.emit("pub range: TextRange", depth + 1)
assert sum_type_info.has_attributes == self.type_info[t.name].needs_cfg(self.type_info)
self.emit_custom(sum_type_info.has_attributes, depth)
self.emit("}", depth)
self.emit(
textwrap.dedent(
f"""
impl{generics_applied} From<{payload_name}{generics_applied}> for {rust_name}{sum_type_info.rust_suffix}{generics_applied} {{
fn from(payload: {payload_name}{generics_applied}) -> Self {{
{rust_name}{sum_type_info.rust_suffix}::{t.name}(payload)
impl<U> From<{payload_name}<U>> for {rust_name}<U> {{
fn from(payload: {payload_name}<U>) -> Self {{
{rust_name}::{t.name}(payload)
}}
}}
"""
@ -342,12 +344,14 @@ class StructVisitor(EmitVisitor):
depth,
)
self.emit(f"impl{generics_applied} Ranged for {payload_name}{generics_applied} {{", depth)
self.emit("#[inline]", depth + 1)
self.emit("fn range(&self) -> TextRange {", depth + 1)
self.emit("self.range", depth + 2)
self.emit("}", depth + 1)
self.emit("}", depth)
# if not sum_type_info.has_attributes:
# self.emit('#[cfg(feature = "more-attributes")]', depth)
# self.emit(f"impl Ranged for {payload_name}<TextRange> {{", depth)
# self.emit("#[inline]", depth + 1)
# self.emit("fn range(&self) -> TextRange {", depth + 1)
# self.emit("self.custom", depth + 2)
# self.emit("}", depth + 1)
# self.emit("}", depth)
self.emit("", depth)
def visitConstructor(self, cons, parent, depth):
@ -362,7 +366,7 @@ class StructVisitor(EmitVisitor):
def visitField(self, field, parent, vis, depth, constructor=None):
typ = rust_type_name(field.type)
field_type = self.type_info.get(field.type)
if field_type and field_type.has_user_data:
if field_type and not field_type.is_simple:
typ = f"{typ}<U>"
# don't box if we're doing Vec<T>, but do box if we're doing Vec<Option<Box<T>>>
if (
@ -388,41 +392,23 @@ class StructVisitor(EmitVisitor):
def visitProduct(self, product, name, depth):
type_info = self.type_info[name]
generics, generics_applied = self.apply_generics(name, "U = ()", "U")
data_name = rust_name = rust_type_name(name)
if product.attributes:
data_name = rust_name + "Data"
product_name = rust_type_name(name)
self.emit_attrs(depth)
has_expr = product_has_expr(product)
if has_expr:
data_def = f"{data_name}{generics}"
else:
data_def = data_name
generics_applied = ""
self.emit(f"pub struct {data_def} {{", depth)
self.emit(f"pub struct {product_name}<U> {{", depth)
for f in product.fields:
self.visit(f, type_info, "pub ", depth + 1)
self.emit("pub range: TextRange", depth + 1)
assert bool(product.attributes) == type_info.needs_cfg(self.type_info)
self.emit_custom(product.attributes, depth + 1)
self.emit("}", depth)
self.emit(f"impl{generics_applied} Ranged for {data_name}{generics_applied} {{", depth)
self.emit("#[inline]", depth + 1)
self.emit("fn range(&self) -> TextRange {", depth + 1)
self.emit("self.range", depth + 2)
self.emit("}", depth + 1)
self.emit("}", depth)
self.emit("", depth)
if product.attributes:
# attributes should just be location info
if not has_expr:
generics_applied = ""
self.emit(
f"pub type {rust_name}<U = ()> = Attributed<{data_name}{generics_applied}, U>;",
depth,
)
# self.emit('#[cfg(feature = "more-attributes")]', depth)
# self.emit(f"impl Ranged for {product_name}<TextRange> {{", depth)
# self.emit("#[inline]", depth + 1)
# self.emit("fn range(&self) -> TextRange {", depth + 1)
# self.emit("self.custom", depth + 2)
# self.emit("}", depth + 1)
# self.emit("}", depth)
self.emit("", depth)
@ -511,7 +497,7 @@ class FoldImplVisitor(EmitVisitor):
self.emit("match node {", depth + 1)
for cons in sum.types:
fields_pattern = self.make_pattern(
enum_name, type_info.rust_suffix, cons.name, cons.fields, simple
enum_name, cons.name, cons.fields, simple
)
self.emit(
f"{fields_pattern[0]} {{ {fields_pattern[1]}}} {fields_pattern[2]} => {{",
@ -552,19 +538,19 @@ class FoldImplVisitor(EmitVisitor):
rust_name = struct_name + "Data"
else:
rust_name = struct_name
fields_pattern = self.make_pattern(rust_name, struct_name, None, product.fields, False)
fields_pattern = self.make_pattern(rust_name, struct_name, product.fields, False)
self.emit(f"let {rust_name} {{ {fields_pattern[1]} }} = node;", depth + 1)
self.gen_construction(rust_name, product.fields, "", depth + 1, False)
if has_attributes:
self.emit("})", depth)
self.emit("}", depth)
def make_pattern(self, rust_name, suffix, fieldname: str, fields, simple_sum: bool):
def make_pattern(self, rust_name, fieldname: str, fields, simple_sum: bool):
if fields or not simple_sum:
header = f"{rust_name}{suffix}::{fieldname}({rust_name}{fieldname}"
header = f"{rust_name}::{fieldname}({rust_name}{fieldname}"
footer = ")"
else:
header = f"{rust_name}{suffix}::{fieldname}"
header = f"{rust_name}::{fieldname}"
footer = ""
body = ",".join(rust_field(f.name) for f in fields)
@ -977,11 +963,6 @@ class ChainOfVisitors:
def write_ast_def(mod, type_info, f):
f.write("""
use crate::text_size::{TextRange};
use crate::Ranged;
""")
StructVisitor(f, type_info).visit(mod)
@ -996,33 +977,43 @@ def write_visitor_def(mod, type_info, f):
VisitorModuleVisitor(f, type_info).visit(mod)
def write_located_def(mod, type_info, f):
f.write(
textwrap.dedent(
"""
use rustpython_parser_core::source_code::SourceRange;
pub type Located<T> = super::generic::Attributed<T, SourceRange>;
"""
)
)
def write_ranged_def(mod, type_info, f):
for info in type_info.values():
if info.empty_field:
continue
if info.has_user_data:
if not info.is_simple:
if info.needs_cfg:
f.write('#[cfg(feature = "more-attributes")]')
f.write(f"""
impl Ranged for {info.rust_sum_name} {{
fn range(&self) -> TextRange {{
self.custom
}}
}}
""")
generics = "::<TextRange>"
else:
generics = ""
f.write(
f"pub type {info.rust_sum_name} = crate::generic::{info.rust_sum_name}{generics};\n"
)
def write_located_def(mod, type_info, f):
for info in type_info.values():
if not info.is_simple:
if info.needs_cfg:
f.write('#[cfg(feature = "more-attributes")]')
f.write(f"""
impl Located for {info.rust_sum_name} {{
fn range(&self) -> SourceRange {{
self.custom
}}
}}
""")
generics = "::<SourceRange>"
else:
generics = ""
f.write(
f"pub type {info.rust_sum_name} = super::generic::{info.rust_sum_name}{generics};\n"
f"pub type {info.rust_sum_name} = crate::generic::{info.rust_sum_name}{generics};\n"
)
if info.rust_suffix:
if info.rust_suffix == "Data" and not info.has_expr:
generics = ""
f.write(
f"pub type {info.rust_sum_name}{info.rust_suffix} = super::generic::{info.rust_sum_name}{info.rust_suffix}{generics};\n"
)
def write_ast_mod(mod, type_info, f):
f.write(
@ -1064,6 +1055,7 @@ def main(
for filename, write in [
("generic", write_ast_def),
("fold", write_fold_def),
("ranged", write_ranged_def),
("located", write_located_def),
("visitor", write_visitor_def),
]: