Remove redundant types

This commit is contained in:
Jeong YunWon 2023-05-07 23:19:39 +09:00
parent 243ca16b34
commit e000b1c304

View file

@ -68,6 +68,7 @@ class TypeInfo:
enum_name: Optional[str]
has_userdata: Optional[bool]
has_attributes: bool
empty_field: bool
children: set
boxed: bool
product: bool
@ -78,6 +79,7 @@ class TypeInfo:
self.enum_name = None
self.has_userdata = None
self.has_attributes = False
self.empty_field = False
self.children = set()
self.boxed = False
self.product = False
@ -192,10 +194,9 @@ class FindUserdataTypesVisitor(asdl.VisitorBase):
info.has_userdata = False
else:
for t in sum.types:
if not t.fields:
continue
t_info = TypeInfo(t.name)
t_info.enum_name = name
t_info.empty_field = not t.fields
self.typeinfo[t.name] = t_info
self.add_children(t.name, t.fields)
if len(sum.types) > 1:
@ -543,43 +544,24 @@ class FoldModuleVisitor(EmitVisitor):
self.emit("}", depth)
class VisitorStructsDefVisitor(StructVisitor):
def visitModule(self, mod, depth):
for dfn in mod.dfns:
self.visit(dfn, depth)
def visitProduct(self, product, name, depth):
pass
def visitSum(self, sum, name, depth):
if not is_simple(sum):
typeinfo = self.typeinfo[name]
if not sum.attributes:
return
for t in sum.types:
typename = t.name + "Node"
has_userdata = any(
getattr(self.typeinfo.get(f.type), "has_userdata", False)
for f in t.fields
)
self.emit(
f"pub struct {typename}Data<{'U=()' if has_userdata else ''}> {{",
depth,
)
for f in t.fields:
self.visit(f, typeinfo, "pub ", depth + 1, t.name)
self.emit("}", depth)
self.emit(
f"pub type {typename}<U = ()> = Located<{typename}Data<{'U' if has_userdata else ''}>, U>;",
depth,
)
self.emit("", depth)
class VisitorTraitDefVisitor(StructVisitor):
def full_name(self, name):
typeinfo = self.typeinfo[name]
if typeinfo.enum_name:
return f"{typeinfo.enum_name}_{name}"
else:
return name
def node_type_name(self, name):
typeinfo = self.typeinfo[name]
if typeinfo.enum_name:
return f"{get_rust_type(typeinfo.enum_name)}{get_rust_type(name)}"
else:
return get_rust_type(name)
def visitModule(self, mod, depth):
self.emit("pub trait Visitor<U=()> {", depth)
for dfn in mod.dfns:
self.visit(dfn, depth + 1)
self.emit("}", depth)
@ -587,46 +569,46 @@ class VisitorTraitDefVisitor(StructVisitor):
def visitType(self, type, depth=0):
self.visit(type.value, type.name, depth)
def emit_visitor(self, nodename, rusttype, depth):
self.emit(f"fn visit_{nodename}(&mut self, node: {rusttype}) {{", depth)
self.emit(f"self.generic_visit_{nodename}(node);", depth + 1)
def emit_visitor(self, nodename, depth, has_node=True):
typeinfo = self.typeinfo[nodename]
if has_node:
node_type = typeinfo.rust_sum_name
node_value = "node"
else:
node_type = "()"
node_value = "()"
self.emit(f"fn visit_{typeinfo.sum_name}(&mut self, node: {node_type}) {{", depth)
self.emit(f"self.generic_visit_{typeinfo.sum_name}({node_value})", depth + 1)
self.emit("}", depth)
def emit_generic_visitor_signature(self, nodename, rusttype, depth):
self.emit(f"fn generic_visit_{nodename}(&mut self, node: {rusttype}) {{", depth)
def emit_generic_visitor_signature(self, nodename, depth, has_node=True):
typeinfo = self.typeinfo[nodename]
if has_node:
node_type = typeinfo.rust_sum_name
else:
node_type = "()"
self.emit(f"fn generic_visit_{typeinfo.sum_name}(&mut self, node: {node_type}) {{", depth)
def emit_empty_generic_visitor(self, nodename, rusttype, depth):
self.emit_generic_visitor_signature(nodename, rusttype, depth)
def emit_empty_generic_visitor(self, nodename, depth):
self.emit_generic_visitor_signature(nodename, depth)
self.emit("}", depth)
def simple_sum(self, sum, name, depth):
enumname = get_rust_type(name)
self.emit_visitor(name, enumname, depth)
self.emit_empty_generic_visitor(name, enumname, depth)
self.emit_visitor(name, depth)
self.emit_empty_generic_visitor(name, depth)
def visit_match_for_type(self, enumname, rustname, type_, depth):
def visit_match_for_type(self, nodename, rustname, type_, depth):
self.emit(f"{rustname}::{type_.name}", depth)
if type_.fields:
self.emit(f"({enumname}{type_.name} {{", depth)
for field in type_.fields:
self.emit(f"{rust_field(field.name)},", depth + 1)
self.emit("})", depth)
self.emit(f"=> self.visit_{type_.name}(", depth)
self.emit(f"{type_.name}Node {{", depth + 2)
self.emit("location: node.location,", depth + 2)
self.emit("end_location: node.end_location,", depth + 2)
self.emit("custom: node.custom,", depth + 2)
self.emit(f"node: {type_.name}NodeData {{", depth + 2)
for field in type_.fields:
self.emit(f"{rust_field(field.name)},", depth + 3)
self.emit("},", depth + 2)
self.emit("}", depth + 1)
self.emit("),", depth)
self.emit("(data)", depth)
data = "data"
else:
data = "()"
self.emit(f"=> self.visit_{nodename}_{type_.name}({data}),", depth)
def visit_sumtype(self, type_, depth):
rustname = get_rust_type(type_.name) + "Node"
self.emit_visitor(type_.name , rustname, depth)
self.emit_generic_visitor_signature(type_.name, rustname, depth)
def visit_sumtype(self, name, type_, depth):
self.emit_visitor(type_.name, depth, has_node=type_.fields)
self.emit_generic_visitor_signature(type_.name, depth, has_node=type_.fields)
for f in type_.fields:
fieldname = rust_field(f.name)
fieldtype = self.typeinfo.get(f.type)
@ -634,20 +616,21 @@ class VisitorTraitDefVisitor(StructVisitor):
continue
if f.opt:
self.emit(f"if let Some(value) = node.node.{fieldname} {{", depth + 1)
self.emit(f"if let Some(value) = node.{fieldname} {{", depth + 1)
elif f.seq:
iterable = f"node.node.{fieldname}"
iterable = f"node.{fieldname}"
if type_.name == "Dict" and f.name == "keys":
iterable = f"{iterable}.into_iter().flatten()"
self.emit(f"for value in {iterable} {{", depth + 1)
else:
self.emit("{", depth + 1)
self.emit(f"let value = node.node.{fieldname};", depth + 2)
self.emit(f"let value = node.{fieldname};", depth + 2)
variable = "value"
if fieldtype.boxed and (not f.seq or f.opt):
variable = "*" + variable
self.emit(f"self.visit_{fieldtype.name}({variable});", depth + 2)
typeinfo = self.typeinfo[fieldtype.name]
self.emit(f"self.visit_{typeinfo.sum_name}({variable});", depth + 2)
self.emit("}", depth + 1)
@ -660,24 +643,23 @@ class VisitorTraitDefVisitor(StructVisitor):
rustname = enumname = get_rust_type(name)
if sum.attributes:
rustname = enumname + "Kind"
self.emit_visitor(name, enumname, depth)
self.emit_generic_visitor_signature(name, enumname, depth)
self.emit_visitor(name, depth)
self.emit_generic_visitor_signature(name, depth)
depth += 1
self.emit("match node.node {", depth)
for t in sum.types:
self.visit_match_for_type(enumname, rustname, t, depth + 1)
self.visit_match_for_type(name, rustname, t, depth + 1)
self.emit("}", depth)
depth -= 1
self.emit("}", depth)
# Now for the visitors for the types
for t in sum.types:
self.visit_sumtype(t, depth)
self.visit_sumtype(name, t, depth)
def visitProduct(self, product, name, depth):
rusttype = get_rust_type(name)
self.emit_visitor(name, rusttype, depth)
self.emit_empty_generic_visitor(name, rusttype, depth)
self.emit_visitor(name, depth)
self.emit_empty_generic_visitor(name, depth)
class VisitorModuleVisitor(EmitVisitor):
@ -687,7 +669,6 @@ class VisitorModuleVisitor(EmitVisitor):
self.emit("#[allow(unused_variables, non_snake_case)]", depth)
self.emit("pub mod visitor {", depth)
self.emit("use super::*;", depth + 1)
VisitorStructsDefVisitor(self.file, self.typeinfo).visit(mod, depth + 1)
VisitorTraitDefVisitor(self.file, self.typeinfo).visit(mod, depth + 1)
self.emit("}", depth)
self.emit("", depth)
@ -980,6 +961,8 @@ def write_located_def(typeinfo, f):
)
)
for info in typeinfo.values():
if info.empty_field:
continue
if info.has_userdata:
generics = "::<SourceRange>"
else: