mirror of
https://github.com/RustPython/Parser.git
synced 2025-07-09 22:25:23 +00:00
Generate a visitor trait to ast_gen.rs
This commit is contained in:
parent
5cf85f0b9d
commit
4de9580b92
2 changed files with 153 additions and 1 deletions
153
ast/asdl_rs.py
153
ast/asdl_rs.py
|
@ -543,6 +543,153 @@ class FoldModuleVisitor(EmitVisitor):
|
|||
self.emit("}", depth)
|
||||
|
||||
|
||||
class VisitorStructsDefVisitor(StructVisitor):
|
||||
def visitModule(self, mod, depth):
|
||||
for dfn in mod.dfns:
|
||||
self.visit(dfn, depth)
|
||||
|
||||
def visitProduct(self, product, name, depth):
|
||||
pass
|
||||
|
||||
def visitSum(self, sum, name, depth):
|
||||
if not is_simple(sum):
|
||||
typeinfo = self.typeinfo[name]
|
||||
if not sum.attributes:
|
||||
return
|
||||
for t in sum.types:
|
||||
typename = t.name + "Node"
|
||||
|
||||
has_userdata = any(
|
||||
getattr(self.typeinfo.get(f.type), "has_userdata", False)
|
||||
for f in t.fields
|
||||
)
|
||||
self.emit(
|
||||
f"pub struct {typename}Data<{'U=()' if has_userdata else ''}> {{",
|
||||
depth,
|
||||
)
|
||||
for f in t.fields:
|
||||
self.visit(f, typeinfo, "pub ", depth + 1, t.name)
|
||||
self.emit("}", depth)
|
||||
self.emit(
|
||||
f"pub type {typename}<U = ()> = Located<{typename}Data<{'U' if has_userdata else ''}>, U>;",
|
||||
depth,
|
||||
)
|
||||
self.emit("", depth)
|
||||
|
||||
|
||||
class VisitorTraitDefVisitor(StructVisitor):
|
||||
def visitModule(self, mod, depth):
|
||||
self.emit("pub trait Visitor<U=()> {", depth)
|
||||
for dfn in mod.dfns:
|
||||
self.visit(dfn, depth + 1)
|
||||
self.emit("}", depth)
|
||||
|
||||
def visitType(self, type, depth=0):
|
||||
self.visit(type.value, type.name, depth)
|
||||
|
||||
def emit_visitor(self, nodename, rusttype, depth):
|
||||
self.emit(f"fn visit_{nodename}(&mut self, node: {rusttype}) {{", depth)
|
||||
self.emit(f"self.generic_visit_{nodename}(node);", depth + 1)
|
||||
self.emit("}", depth)
|
||||
|
||||
def emit_generic_visitor_signature(self, nodename, rusttype, depth):
|
||||
self.emit(f"fn generic_visit_{nodename}(&mut self, node: {rusttype}) {{", depth)
|
||||
|
||||
def emit_empty_generic_visitor(self, nodename, rusttype, depth):
|
||||
self.emit_generic_visitor_signature(nodename, rusttype, depth)
|
||||
self.emit("}", depth)
|
||||
|
||||
def simple_sum(self, sum, name, depth):
|
||||
rustname = get_rust_type(name)
|
||||
self.emit_visitor(name, rustname, depth)
|
||||
self.emit_empty_generic_visitor(name, rustname, depth)
|
||||
|
||||
def visit_match_for_type(self, enumname, type_, depth):
|
||||
self.emit(f"{enumname}::{type_.name} {{", depth)
|
||||
for field in type_.fields:
|
||||
self.emit(f"{rust_field(field.name)},", depth + 1)
|
||||
self.emit(f"}} => self.visit_{type_.name}(", depth)
|
||||
self.emit(f"{type_.name}Node {{", depth + 2)
|
||||
self.emit("location: node.location,", depth + 2)
|
||||
self.emit("end_location: node.end_location,", depth + 2)
|
||||
self.emit("custom: node.custom,", depth + 2)
|
||||
self.emit(f"node: {type_.name}NodeData {{", depth + 2)
|
||||
for field in type_.fields:
|
||||
self.emit(f"{rust_field(field.name)},", depth + 3)
|
||||
self.emit("},", depth + 2)
|
||||
self.emit("}", depth + 1)
|
||||
self.emit("),", depth)
|
||||
|
||||
def visit_sumtype(self, type_, depth):
|
||||
rustname = get_rust_type(type_.name) + "Node"
|
||||
self.emit_visitor(type_.name, rustname, depth)
|
||||
self.emit_generic_visitor_signature(type_.name, rustname, depth)
|
||||
for f in type_.fields:
|
||||
fieldname = rust_field(f.name)
|
||||
fieldtype = self.typeinfo.get(f.type)
|
||||
if not (fieldtype and fieldtype.has_userdata):
|
||||
continue
|
||||
|
||||
if f.opt:
|
||||
self.emit(f"if let Some(value) = node.node.{fieldname} {{", depth + 1)
|
||||
elif f.seq:
|
||||
iterable = f"node.node.{fieldname}"
|
||||
if type_.name == "Dict" and f.name == "keys":
|
||||
iterable = f"{iterable}.into_iter().flatten()"
|
||||
self.emit(f"for value in {iterable} {{", depth + 1)
|
||||
else:
|
||||
self.emit("{", depth + 1)
|
||||
self.emit(f"let value = node.node.{fieldname};", depth + 2)
|
||||
|
||||
variable = "value"
|
||||
if fieldtype.boxed and (not f.seq or f.opt):
|
||||
variable = "*" + variable
|
||||
self.emit(f"self.visit_{fieldtype.name}({variable});", depth + 2)
|
||||
|
||||
self.emit("}", depth + 1)
|
||||
|
||||
self.emit("}", depth)
|
||||
|
||||
def sum_with_constructors(self, sum, name, depth):
|
||||
if not sum.attributes:
|
||||
return
|
||||
|
||||
rustname = enumname = get_rust_type(name)
|
||||
if sum.attributes:
|
||||
enumname += "Kind"
|
||||
self.emit_visitor(name, rustname, depth)
|
||||
self.emit_generic_visitor_signature(name, rustname, depth)
|
||||
depth += 1
|
||||
self.emit("match node.node {", depth)
|
||||
for t in sum.types:
|
||||
self.visit_match_for_type(enumname, t, depth + 1)
|
||||
self.emit("}", depth)
|
||||
depth -= 1
|
||||
self.emit("}", depth)
|
||||
|
||||
# Now for the visitors for the types
|
||||
for t in sum.types:
|
||||
self.visit_sumtype(t, depth)
|
||||
|
||||
def visitProduct(self, product, name, depth):
|
||||
rusttype = get_rust_type(name)
|
||||
self.emit_visitor(name, rusttype, depth)
|
||||
self.emit_empty_generic_visitor(name, rusttype, depth)
|
||||
|
||||
|
||||
class VisitorModuleVisitor(EmitVisitor):
|
||||
def visitModule(self, mod):
|
||||
depth = 0
|
||||
self.emit('#[cfg(feature = "visitor")]', depth)
|
||||
self.emit("#[allow(unused_variables, non_snake_case)]", depth)
|
||||
self.emit("pub mod visitor {", depth)
|
||||
self.emit("use super::*;", depth + 1)
|
||||
VisitorStructsDefVisitor(self.file, self.typeinfo).visit(mod, depth + 1)
|
||||
VisitorTraitDefVisitor(self.file, self.typeinfo).visit(mod, depth + 1)
|
||||
self.emit("}", depth)
|
||||
self.emit("", depth)
|
||||
|
||||
|
||||
class ClassDefVisitor(EmitVisitor):
|
||||
def visitModule(self, mod):
|
||||
for dfn in mod.dfns:
|
||||
|
@ -811,7 +958,11 @@ def write_generic_def(mod, typeinfo, f):
|
|||
)
|
||||
)
|
||||
|
||||
c = ChainOfVisitors(StructVisitor(f, typeinfo), FoldModuleVisitor(f, typeinfo))
|
||||
c = ChainOfVisitors(
|
||||
StructVisitor(f, typeinfo),
|
||||
FoldModuleVisitor(f, typeinfo),
|
||||
VisitorModuleVisitor(f, typeinfo),
|
||||
)
|
||||
c.visit(mod)
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue