spell check ast/asdl_rs.py

This commit is contained in:
Jeong YunWon 2023-05-10 18:05:39 +09:00
parent 75f6ce1ae5
commit d8822d1091
3 changed files with 183 additions and 176 deletions

View file

@ -71,3 +71,4 @@ jobs:
'core/**/*.rs'
'literal/**/*.rs'
'parser/**/*.rs'
'ast/asdl_rs.py'

View file

@ -1,3 +1,5 @@
# spell-checker:words dfn dfns
#! /usr/bin/env python
"""Generate Rust code from an ASDL description."""
@ -12,7 +14,7 @@ from typing import Optional, Dict
import asdl
TABSIZE = 4
AUTOGEN_MESSAGE = "// File automatically generated by {}.\n\n"
AUTO_GEN_MESSAGE = "// File automatically generated by {}.\n\n"
builtin_type_mapping = {
"identifier": "Ident",
@ -23,7 +25,7 @@ builtin_type_mapping = {
assert builtin_type_mapping.keys() == asdl.builtin_types
def get_rust_type(name):
def rust_type_name(name):
"""Return a string for the C name of the type.
This function special cases the default types provided by asdl.
@ -66,7 +68,7 @@ def asdl_of(name, obj):
class TypeInfo:
name: str
enum_name: Optional[str]
has_userdata: Optional[bool]
has_user_data: Optional[bool]
has_attributes: bool
empty_field: bool
children: set
@ -77,7 +79,7 @@ class TypeInfo:
def __init__(self, name):
self.name = name
self.enum_name = None
self.has_userdata = None
self.has_user_data = None
self.has_attributes = False
self.empty_field = False
self.children = set()
@ -90,7 +92,7 @@ class TypeInfo:
@property
def rust_name(self):
return get_rust_type(self.name)
return rust_type_name(self.name)
@property
def sum_name(self):
@ -101,11 +103,11 @@ class TypeInfo:
@property
def rust_sum_name(self):
rust_name = get_rust_type(self.name)
rust_name = rust_type_name(self.name)
if self.enum_name is None:
return rust_name
else:
name = get_rust_type(self.enum_name) + rust_name
name = rust_type_name(self.enum_name) + rust_name
return name
@property
@ -121,30 +123,30 @@ class TypeInfo:
else:
return ""
def determine_userdata(self, typeinfo, stack):
def determine_user_data(self, type_info, stack):
if self.name in stack:
return None
stack.add(self.name)
for child, child_seq in self.children:
if child in asdl.builtin_types:
continue
childinfo = typeinfo[child]
child_has_userdata = childinfo.determine_userdata(typeinfo, stack)
if self.has_userdata is None and child_has_userdata is True:
self.has_userdata = True
child_info = type_info[child]
child_has_user_data = child_info.determine_user_data(type_info, stack)
if self.has_user_data is None and child_has_user_data is True:
self.has_user_data = True
stack.remove(self.name)
return self.has_userdata
return self.has_user_data
class TypeInfoMixin:
typeinfo: Dict[str, TypeInfo]
type_info: Dict[str, TypeInfo]
def has_userdata(self, typ):
return self.typeinfo[typ].has_userdata
def has_user_data(self, typ):
return self.type_info[typ].has_user_data
def get_generics(self, typ, *generics):
if self.has_userdata(typ):
def apply_generics(self, typ, *generics):
if self.has_user_data(typ):
return [f"<{g}>" for g in generics]
else:
return ["" for g in generics]
@ -153,9 +155,9 @@ class TypeInfoMixin:
class EmitVisitor(asdl.VisitorBase, TypeInfoMixin):
"""Visit that emits lines"""
def __init__(self, file, typeinfo):
def __init__(self, file, type_info):
self.file = file
self.typeinfo = typeinfo
self.type_info = type_info
self.identifiers = set()
super(EmitVisitor, self).__init__()
@ -172,48 +174,48 @@ class EmitVisitor(asdl.VisitorBase, TypeInfoMixin):
self.file.write(line + "\n")
class FindUserdataTypesVisitor(asdl.VisitorBase):
def __init__(self, typeinfo):
self.typeinfo = typeinfo
class FindUserDataTypesVisitor(asdl.VisitorBase):
def __init__(self, type_info):
self.type_info = type_info
super().__init__()
def visitModule(self, mod):
for dfn in mod.dfns:
self.visit(dfn)
stack = set()
for info in self.typeinfo.values():
info.determine_userdata(self.typeinfo, stack)
for info in self.type_info.values():
info.determine_user_data(self.type_info, stack)
def visitType(self, type):
self.typeinfo[type.name] = TypeInfo(type.name)
self.type_info[type.name] = TypeInfo(type.name)
self.visit(type.value, type.name)
def visitSum(self, sum, name):
info = self.typeinfo[name]
info = self.type_info[name]
if is_simple(sum):
info.has_userdata = False
info.has_user_data = False
else:
for t in sum.types:
t_info = TypeInfo(t.name)
t_info.enum_name = name
t_info.empty_field = not t.fields
self.typeinfo[t.name] = t_info
self.type_info[t.name] = t_info
self.add_children(t.name, t.fields)
if len(sum.types) > 1:
info.boxed = True
if sum.attributes:
# attributes means located, which has the `custom: U` field
info.has_userdata = True
info.has_user_data = True
info.has_attributes = True
for variant in sum.types:
self.add_children(name, variant.fields)
def visitProduct(self, product, name):
info = self.typeinfo[name]
info = self.type_info[name]
if product.attributes:
# attributes means located, which has the `custom: U` field
info.has_userdata = True
info.has_user_data = True
info.has_attributes = True
info.has_expr = product_has_expr(product)
if len(product.fields) > 2:
@ -222,7 +224,9 @@ class FindUserdataTypesVisitor(asdl.VisitorBase):
self.add_children(name, product.fields)
def add_children(self, name, fields):
self.typeinfo[name].children.update((field.type, field.seq) for field in fields)
self.type_info[name].children.update(
(field.type, field.seq) for field in fields
)
def rust_field(field_name):
@ -237,7 +241,7 @@ def product_has_expr(product):
class StructVisitor(EmitVisitor):
"""Visitor to generate typedefs for AST."""
"""Visitor to generate type-defs for AST."""
def __init__(self, *args, **kw):
super().__init__(*args, **kw)
@ -260,59 +264,59 @@ class StructVisitor(EmitVisitor):
self.emit("#[derive(Clone, Debug, PartialEq)]", depth)
def simple_sum(self, sum, name, depth):
rustname = get_rust_type(name)
rust_name = rust_type_name(name)
self.emit_attrs(depth)
self.emit(f"pub enum {rustname} {{", depth)
self.emit(f"pub enum {rust_name} {{", depth)
for variant in sum.types:
self.emit(f"{variant.name},", depth + 1)
self.emit("}", depth)
self.emit("", depth)
def sum_with_constructors(self, sum, name, depth):
typeinfo = self.typeinfo[name]
suffix = typeinfo.rust_suffix
rustname = get_rust_type(name)
type_info = self.type_info[name]
suffix = type_info.rust_suffix
rust_name = rust_type_name(name)
# all the attributes right now are for location, so if it has attrs we
# can just wrap it in Attributed<>
for t in sum.types:
if not t.fields:
continue
self.sum_subtype_struct(typeinfo, t, rustname, depth)
self.sum_subtype_struct(type_info, t, rust_name, depth)
generics, generics_applied = self.get_generics(name, "U = ()", "U")
generics, generics_applied = self.apply_generics(name, "U = ()", "U")
self.emit_attrs(depth)
self.emit(f"pub enum {rustname}{suffix}{generics} {{", depth)
self.emit(f"pub enum {rust_name}{suffix}{generics} {{", depth)
for t in sum.types:
if t.fields:
(t_generics_applied,) = self.get_generics(t.name, "U")
(t_generics_applied,) = self.apply_generics(t.name, "U")
self.emit(
f"{t.name}({rustname}{t.name}{t_generics_applied}),", depth + 1
f"{t.name}({rust_name}{t.name}{t_generics_applied}),", depth + 1
)
else:
self.emit(f"{t.name},", depth + 1)
self.emit("}", depth)
if typeinfo.has_attributes:
if type_info.has_attributes:
self.emit(
f"pub type {rustname}<U = ()> = Attributed<{rustname}{suffix}{generics_applied}, U>;",
f"pub type {rust_name}<U = ()> = Attributed<{rust_name}{suffix}{generics_applied}, U>;",
depth,
)
self.emit("", depth)
def sum_subtype_struct(self, sum_typeinfo, t, rustname, depth):
def sum_subtype_struct(self, sum_type_info, t, rust_name, depth):
self.emit_attrs(depth)
generics, generics_applied = self.get_generics(t.name, "U = ()", "U")
payload_name = f"{rustname}{t.name}"
generics, generics_applied = self.apply_generics(t.name, "U = ()", "U")
payload_name = f"{rust_name}{t.name}"
self.emit(f"pub struct {payload_name}{generics} {{", depth)
for f in t.fields:
self.visit(f, sum_typeinfo, "pub ", depth + 1, t.name)
self.visit(f, sum_type_info, "pub ", depth + 1, t.name)
self.emit("}", depth)
self.emit(
textwrap.dedent(
f"""
impl{generics_applied} From<{payload_name}{generics_applied}> for {rustname}{sum_typeinfo.rust_suffix}{generics_applied} {{
impl{generics_applied} From<{payload_name}{generics_applied}> for {rust_name}{sum_type_info.rust_suffix}{generics_applied} {{
fn from(payload: {payload_name}{generics_applied}) -> Self {{
{rustname}{sum_typeinfo.rust_suffix}::{t.name}(payload)
{rust_name}{sum_type_info.rust_suffix}::{t.name}(payload)
}}
}}
"""
@ -330,14 +334,14 @@ class StructVisitor(EmitVisitor):
self.emit(f"{cons.name},", depth)
def visitField(self, field, parent, vis, depth, constructor=None):
typ = get_rust_type(field.type)
fieldtype = self.typeinfo.get(field.type)
if fieldtype and fieldtype.has_userdata:
typ = rust_type_name(field.type)
field_type = self.type_info.get(field.type)
if field_type and field_type.has_user_data:
typ = f"{typ}<U>"
# don't box if we're doing Vec<T>, but do box if we're doing Vec<Option<Box<T>>>
if (
fieldtype
and fieldtype.boxed
field_type
and field_type.boxed
and (not (parent.product or field.seq) or field.opt)
):
typ = f"Box<{typ}>"
@ -355,27 +359,27 @@ class StructVisitor(EmitVisitor):
self.emit(f"{vis}{name}: {typ},", depth)
def visitProduct(self, product, name, depth):
typeinfo = self.typeinfo[name]
generics, generics_applied = self.get_generics(name, "U = ()", "U")
dataname = rustname = get_rust_type(name)
type_info = self.type_info[name]
generics, generics_applied = self.apply_generics(name, "U = ()", "U")
data_name = rust_name = rust_type_name(name)
if product.attributes:
dataname = rustname + "Data"
data_name = rust_name + "Data"
self.emit_attrs(depth)
has_expr = product_has_expr(product)
if has_expr:
datadef = f"{dataname}{generics}"
data_def = f"{data_name}{generics}"
else:
datadef = dataname
self.emit(f"pub struct {datadef} {{", depth)
data_def = data_name
self.emit(f"pub struct {data_def} {{", depth)
for f in product.fields:
self.visit(f, typeinfo, "pub ", depth + 1)
self.visit(f, type_info, "pub ", depth + 1)
self.emit("}", depth)
if product.attributes:
# attributes should just be location info
if not has_expr:
generics_applied = ""
self.emit(
f"pub type {rustname}<U = ()> = Attributed<{dataname}{generics_applied}, U>;",
f"pub type {rust_name}<U = ()> = Attributed<{data_name}{generics_applied}, U>;",
depth,
)
self.emit("", depth)
@ -411,10 +415,10 @@ class FoldTraitDefVisitor(EmitVisitor):
def visitType(self, type, depth):
name = type.name
apply_u, apply_target_u = self.get_generics(name, "U", "Self::TargetU")
enumname = get_rust_type(name)
apply_u, apply_target_u = self.apply_generics(name, "U", "Self::TargetU")
enum_name = rust_type_name(name)
self.emit(
f"fn fold_{name}(&mut self, node: {enumname}{apply_u}) -> Result<{enumname}{apply_target_u}, Self::Error> {{",
f"fn fold_{name}(&mut self, node: {enum_name}{apply_u}) -> Result<{enum_name}{apply_target_u}, Self::Error> {{",
depth,
)
self.emit(f"fold_{name}(self, node)", depth + 1)
@ -439,14 +443,14 @@ class FoldImplVisitor(EmitVisitor):
self.visit(type.value, type.name, depth)
def visitSum(self, sum, name, depth):
typeinfo = self.typeinfo[name]
apply_t, apply_u, apply_target_u = self.get_generics(
type_info = self.type_info[name]
apply_t, apply_u, apply_target_u = self.apply_generics(
name, "T", "U", "F::TargetU"
)
enumname = get_rust_type(name)
enum_name = rust_type_name(name)
self.emit(f"impl<T, U> Foldable<T, U> for {enumname}{apply_t} {{", depth)
self.emit(f"type Mapped = {enumname}{apply_u};", depth + 1)
self.emit(f"impl<T, U> Foldable<T, U> for {enum_name}{apply_t} {{", depth)
self.emit(f"type Mapped = {enum_name}{apply_u};", depth + 1)
self.emit(
"fn fold<F: Fold<T, TargetU = U> + ?Sized>(self, folder: &mut F) -> Result<Self::Mapped, F::Error> {",
depth + 1,
@ -456,16 +460,16 @@ class FoldImplVisitor(EmitVisitor):
self.emit("}", depth)
self.emit(
f"pub fn fold_{name}<U, F: Fold<U> + ?Sized>(#[allow(unused)] folder: &mut F, node: {enumname}{apply_u}) -> Result<{enumname}{apply_target_u}, F::Error> {{",
f"pub fn fold_{name}<U, F: Fold<U> + ?Sized>(#[allow(unused)] folder: &mut F, node: {enum_name}{apply_u}) -> Result<{enum_name}{apply_target_u}, F::Error> {{",
depth,
)
if typeinfo.has_attributes:
if type_info.has_attributes:
self.emit("fold_attributed(folder, node, |folder, node| {", depth)
self.emit("match node {", depth + 1)
for cons in sum.types:
fields_pattern = self.make_pattern(
enumname, typeinfo.rust_suffix, cons.name, cons.fields
enum_name, type_info.rust_suffix, cons.name, cons.fields
)
self.emit(
f"{fields_pattern[0]} {{ {fields_pattern[1]} }} {fields_pattern[2]} => {{",
@ -476,19 +480,19 @@ class FoldImplVisitor(EmitVisitor):
)
self.emit("}", depth + 2)
self.emit("}", depth + 1)
if typeinfo.has_attributes:
if type_info.has_attributes:
self.emit("})", depth)
self.emit("}", depth)
def visitProduct(self, product, name, depth):
apply_t, apply_u, apply_target_u = self.get_generics(
apply_t, apply_u, apply_target_u = self.apply_generics(
name, "T", "U", "F::TargetU"
)
structname = get_rust_type(name)
struct_name = rust_type_name(name)
has_attributes = bool(product.attributes)
self.emit(f"impl<T, U> Foldable<T, U> for {structname}{apply_t} {{", depth)
self.emit(f"type Mapped = {structname}{apply_u};", depth + 1)
self.emit(f"impl<T, U> Foldable<T, U> for {struct_name}{apply_t} {{", depth)
self.emit(f"type Mapped = {struct_name}{apply_u};", depth + 1)
self.emit(
"fn fold<F: Fold<T, TargetU = U> + ?Sized>(self, folder: &mut F) -> Result<Self::Mapped, F::Error> {",
depth + 1,
@ -498,27 +502,27 @@ class FoldImplVisitor(EmitVisitor):
self.emit("}", depth)
self.emit(
f"pub fn fold_{name}<U, F: Fold<U> + ?Sized>(#[allow(unused)] folder: &mut F, node: {structname}{apply_u}) -> Result<{structname}{apply_target_u}, F::Error> {{",
f"pub fn fold_{name}<U, F: Fold<U> + ?Sized>(#[allow(unused)] folder: &mut F, node: {struct_name}{apply_u}) -> Result<{struct_name}{apply_target_u}, F::Error> {{",
depth,
)
if has_attributes:
self.emit("fold_attributed(folder, node, |folder, node| {", depth)
rustname = structname + "Data"
rust_name = struct_name + "Data"
else:
rustname = structname
fields_pattern = self.make_pattern(rustname, structname, None, product.fields)
self.emit(f"let {rustname} {{ {fields_pattern[1]} }} = node;", depth + 1)
self.gen_construction(rustname, product.fields, "", depth + 1)
rust_name = struct_name
fields_pattern = self.make_pattern(rust_name, struct_name, None, product.fields)
self.emit(f"let {rust_name} {{ {fields_pattern[1]} }} = node;", depth + 1)
self.gen_construction(rust_name, product.fields, "", depth + 1)
if has_attributes:
self.emit("})", depth)
self.emit("}", depth)
def make_pattern(self, rustname, suffix, fieldname, fields):
def make_pattern(self, rust_name, suffix, fieldname, fields):
if fields:
header = f"{rustname}{suffix}::{fieldname}({rustname}{fieldname}"
header = f"{rust_name}{suffix}::{fieldname}({rust_name}{fieldname}"
footer = ")"
else:
header = f"{rustname}{suffix}::{fieldname}"
header = f"{rust_name}{suffix}::{fieldname}"
footer = ""
body = ",".join(rust_field(f.name) for f in fields)
@ -536,24 +540,24 @@ class FoldModuleVisitor(EmitVisitor):
def visitModule(self, mod):
depth = 0
self.emit("use crate::fold_helpers::Foldable;", depth)
FoldTraitDefVisitor(self.file, self.typeinfo).visit(mod, depth)
FoldImplVisitor(self.file, self.typeinfo).visit(mod, depth)
FoldTraitDefVisitor(self.file, self.type_info).visit(mod, depth)
FoldImplVisitor(self.file, self.type_info).visit(mod, depth)
class VisitorTraitDefVisitor(StructVisitor):
def full_name(self, name):
typeinfo = self.typeinfo[name]
if typeinfo.enum_name:
return f"{typeinfo.enum_name}_{name}"
type_info = self.type_info[name]
if type_info.enum_name:
return f"{type_info.enum_name}_{name}"
else:
return name
def node_type_name(self, name):
typeinfo = self.typeinfo[name]
if typeinfo.enum_name:
return f"{get_rust_type(typeinfo.enum_name)}{get_rust_type(name)}"
type_info = self.type_info[name]
if type_info.enum_name:
return f"{rust_type_name(type_info.enum_name)}{rust_type_name(name)}"
else:
return get_rust_type(name)
return rust_type_name(name)
def visitModule(self, mod, depth):
self.emit("pub trait Visitor<U=()> {", depth)
@ -566,27 +570,27 @@ class VisitorTraitDefVisitor(StructVisitor):
self.visit(type.value, type.name, depth)
def emit_visitor(self, nodename, depth, has_node=True):
typeinfo = self.typeinfo[nodename]
type_info = self.type_info[nodename]
if has_node:
node_type = typeinfo.rust_sum_name
node_type = type_info.rust_sum_name
node_value = "node"
else:
node_type = "()"
node_value = "()"
self.emit(
f"fn visit_{typeinfo.sum_name}(&mut self, node: {node_type}) {{", depth
f"fn visit_{type_info.sum_name}(&mut self, node: {node_type}) {{", depth
)
self.emit(f"self.generic_visit_{typeinfo.sum_name}({node_value})", depth + 1)
self.emit(f"self.generic_visit_{type_info.sum_name}({node_value})", depth + 1)
self.emit("}", depth)
def emit_generic_visitor_signature(self, nodename, depth, has_node=True):
typeinfo = self.typeinfo[nodename]
type_info = self.type_info[nodename]
if has_node:
node_type = typeinfo.rust_sum_name
node_type = type_info.rust_sum_name
else:
node_type = "()"
self.emit(
f"fn generic_visit_{typeinfo.sum_name}(&mut self, node: {node_type}) {{",
f"fn generic_visit_{type_info.sum_name}(&mut self, node: {node_type}) {{",
depth,
)
@ -598,8 +602,8 @@ class VisitorTraitDefVisitor(StructVisitor):
self.emit_visitor(name, depth)
self.emit_empty_generic_visitor(name, depth)
def visit_match_for_type(self, nodename, rustname, type_, depth):
self.emit(f"{rustname}::{type_.name}", depth)
def visit_match_for_type(self, nodename, rust_name, type_, depth):
self.emit(f"{rust_name}::{type_.name}", depth)
if type_.fields:
self.emit("(data)", depth)
data = "data"
@ -607,13 +611,13 @@ class VisitorTraitDefVisitor(StructVisitor):
data = "()"
self.emit(f"=> self.visit_{nodename}_{type_.name}({data}),", depth)
def visit_sumtype(self, name, type_, depth):
def visit_sum_type(self, name, type_, depth):
self.emit_visitor(type_.name, depth, has_node=type_.fields)
self.emit_generic_visitor_signature(type_.name, depth, has_node=type_.fields)
for f in type_.fields:
fieldname = rust_field(f.name)
fieldtype = self.typeinfo.get(f.type)
if not (fieldtype and fieldtype.has_userdata):
field_type = self.type_info.get(f.type)
if not (field_type and field_type.has_user_data):
continue
if f.opt:
@ -628,10 +632,10 @@ class VisitorTraitDefVisitor(StructVisitor):
self.emit(f"let value = node.{fieldname};", depth + 2)
variable = "value"
if fieldtype.boxed and (not f.seq or f.opt):
if field_type.boxed and (not f.seq or f.opt):
variable = "*" + variable
typeinfo = self.typeinfo[fieldtype.name]
self.emit(f"self.visit_{typeinfo.sum_name}({variable});", depth + 2)
type_info = self.type_info[field_type.name]
self.emit(f"self.visit_{type_info.sum_name}({variable});", depth + 2)
self.emit("}", depth + 1)
@ -641,22 +645,22 @@ class VisitorTraitDefVisitor(StructVisitor):
if not sum.attributes:
return
rustname = enumname = get_rust_type(name)
rust_name = enum_name = rust_type_name(name)
if sum.attributes:
rustname = enumname + "Kind"
rust_name = enum_name + "Kind"
self.emit_visitor(name, depth)
self.emit_generic_visitor_signature(name, depth)
depth += 1
self.emit("match node.node {", depth)
for t in sum.types:
self.visit_match_for_type(name, rustname, t, depth + 1)
self.visit_match_for_type(name, rust_name, t, depth + 1)
self.emit("}", depth)
depth -= 1
self.emit("}", depth)
# Now for the visitors for the types
for t in sum.types:
self.visit_sumtype(name, t, depth)
self.visit_sum_type(name, t, depth)
def visitProduct(self, product, name, depth):
self.emit_visitor(name, depth)
@ -667,10 +671,10 @@ class VisitorModuleVisitor(EmitVisitor):
def visitModule(self, mod):
depth = 0
self.emit("#[allow(unused_variables, non_snake_case)]", depth)
VisitorTraitDefVisitor(self.file, self.typeinfo).visit(mod, depth)
VisitorTraitDefVisitor(self.file, self.type_info).visit(mod, depth)
class ClassDefVisitor(EmitVisitor):
class class_defVisitor(EmitVisitor):
def visitModule(self, mod):
for dfn in mod.dfns:
self.visit(dfn)
@ -679,32 +683,32 @@ class ClassDefVisitor(EmitVisitor):
self.visit(type.value, type.name, depth)
def visitSum(self, sum, name, depth):
structname = "NodeKind" + get_rust_type(name)
struct_name = "NodeKind" + rust_type_name(name)
self.emit(
f'#[pyclass(module = "_ast", name = {json.dumps(name)}, base = "AstNode")]',
depth,
)
self.emit(f"struct {structname};", depth)
self.emit(f"struct {struct_name};", depth)
self.emit("#[pyclass(flags(HAS_DICT, BASETYPE))]", depth)
self.emit(f"impl {structname} {{}}", depth)
self.emit(f"impl {struct_name} {{}}", depth)
for cons in sum.types:
self.visit(cons, sum.attributes, structname, depth)
self.visit(cons, sum.attributes, struct_name, depth)
def visitConstructor(self, cons, attrs, base, depth):
self.gen_classdef(cons.name, cons.fields, attrs, depth, base)
self.gen_class_def(cons.name, cons.fields, attrs, depth, base)
def visitProduct(self, product, name, depth):
self.gen_classdef(name, product.fields, product.attributes, depth)
self.gen_class_def(name, product.fields, product.attributes, depth)
def gen_classdef(self, name, fields, attrs, depth, base="AstNode"):
structname = "Node" + get_rust_type(name)
def gen_class_def(self, name, fields, attrs, depth, base="AstNode"):
struct_name = "Node" + rust_type_name(name)
self.emit(
f'#[pyclass(module = "_ast", name = {json.dumps(name)}, base = {json.dumps(base)})]',
depth,
)
self.emit(f"struct {structname};", depth)
self.emit(f"struct {struct_name};", depth)
self.emit("#[pyclass(flags(HAS_DICT, BASETYPE))]", depth)
self.emit(f"impl {structname} {{", depth)
self.emit(f"impl {struct_name} {{", depth)
self.emit("#[extend_class]", depth + 1)
self.emit(
"fn extend_class_with_fields(ctx: &Context, class: &'static Py<PyType>) {",
@ -745,7 +749,7 @@ class ExtendModuleVisitor(EmitVisitor):
self.visit(type.value, type.name, depth)
def visitSum(self, sum, name, depth):
rust_name = get_rust_type(name)
rust_name = rust_type_name(name)
self.emit(
f"{json.dumps(name)} => NodeKind{rust_name}::make_class(&vm.ctx),", depth
)
@ -759,7 +763,7 @@ class ExtendModuleVisitor(EmitVisitor):
self.gen_extension(name, depth)
def gen_extension(self, name, depth):
rust_name = get_rust_type(name)
rust_name = rust_type_name(name)
self.emit(f"{json.dumps(name)} => Node{rust_name}::make_class(&vm.ctx),", depth)
@ -772,56 +776,57 @@ class TraitImplVisitor(EmitVisitor):
self.visit(type.value, type.name, depth)
def visitSum(self, sum, name, depth):
rustname = enumname = get_rust_type(name)
rust_name = enum_name = rust_type_name(name)
if sum.attributes:
rustname = enumname + "Kind"
rust_name = enum_name + "Kind"
self.emit(f"impl NamedNode for ast::located::{rustname} {{", depth)
self.emit(f"impl NamedNode for ast::located::{rust_name} {{", depth)
self.emit(f"const NAME: &'static str = {json.dumps(name)};", depth + 1)
self.emit("}", depth)
self.emit(f"impl Node for ast::located::{rustname} {{", depth)
self.emit(f"impl Node for ast::located::{rust_name} {{", depth)
self.emit(
"fn ast_to_object(self, _vm: &VirtualMachine) -> PyObjectRef {", depth + 1
)
self.emit("match self {", depth + 2)
for variant in sum.types:
self.constructor_to_object(variant, enumname, rustname, depth + 3)
self.constructor_to_object(variant, enum_name, rust_name, depth + 3)
self.emit("}", depth + 2)
self.emit("}", depth + 1)
self.emit(
"fn ast_from_object(_vm: &VirtualMachine, _object: PyObjectRef) -> PyResult<Self> {",
depth + 1,
)
self.gen_sum_fromobj(sum, name, enumname, rustname, depth + 2)
self.gen_sum_from_object(sum, name, enum_name, rust_name, depth + 2)
self.emit("}", depth + 1)
self.emit("}", depth)
def constructor_to_object(self, cons, enumname, rustname, depth):
self.emit(f"ast::located::{rustname}::{cons.name}", depth)
def constructor_to_object(self, cons, enum_name, rust_name, depth):
self.emit(f"ast::located::{rust_name}::{cons.name}", depth)
if cons.fields:
fields_pattern = self.make_pattern(cons.fields)
self.emit(
f"( ast::located::{enumname}{cons.name} {{ {fields_pattern} }} )", depth
f"( ast::located::{enum_name}{cons.name} {{ {fields_pattern} }} )",
depth,
)
self.emit(" => {", depth)
self.make_node(cons.name, cons.fields, depth + 1)
self.emit("}", depth)
def visitProduct(self, product, name, depth):
structname = get_rust_type(name)
struct_name = rust_type_name(name)
if product.attributes:
structname += "Data"
struct_name += "Data"
self.emit(f"impl NamedNode for ast::located::{structname} {{", depth)
self.emit(f"impl NamedNode for ast::located::{struct_name} {{", depth)
self.emit(f"const NAME: &'static str = {json.dumps(name)};", depth + 1)
self.emit("}", depth)
self.emit(f"impl Node for ast::located::{structname} {{", depth)
self.emit(f"impl Node for ast::located::{struct_name} {{", depth)
self.emit(
"fn ast_to_object(self, _vm: &VirtualMachine) -> PyObjectRef {", depth + 1
)
fields_pattern = self.make_pattern(product.fields)
self.emit(
f"let ast::located::{structname} {{ {fields_pattern} }} = self;", depth + 2
f"let ast::located::{struct_name} {{ {fields_pattern} }} = self;", depth + 2
)
self.make_node(name, product.fields, depth + 2)
self.emit("}", depth + 1)
@ -829,12 +834,12 @@ class TraitImplVisitor(EmitVisitor):
"fn ast_from_object(_vm: &VirtualMachine, _object: PyObjectRef) -> PyResult<Self> {",
depth + 1,
)
self.gen_product_fromobj(product, name, structname, depth + 2)
self.gen_product_from_object(product, name, struct_name, depth + 2)
self.emit("}", depth + 1)
self.emit("}", depth)
def make_node(self, variant, fields, depth):
rust_variant = get_rust_type(variant)
rust_variant = rust_type_name(variant)
self.emit(
f"let _node = AstNode.into_ref_with_type(_vm, Node{rust_variant}::static_type().to_owned()).unwrap();",
depth,
@ -851,9 +856,9 @@ class TraitImplVisitor(EmitVisitor):
def make_pattern(self, fields):
return ",".join(rust_field(f.name) for f in fields)
def gen_sum_fromobj(self, sum, sumname, enumname, rustname, depth):
def gen_sum_from_object(self, sum, sum_name, enum_name, rust_name, depth):
# if sum.attributes:
# self.extract_location(sumname, depth)
# self.extract_location(sum_name, depth)
self.emit("let _cls = _object.class();", depth)
self.emit("Ok(", depth)
@ -861,26 +866,26 @@ class TraitImplVisitor(EmitVisitor):
self.emit(f"if _cls.is(Node{cons.name}::static_type()) {{", depth)
if cons.fields:
self.emit(
f"ast::located::{rustname}::{cons.name} (ast::located::{enumname}{cons.name} {{",
f"ast::located::{rust_name}::{cons.name} (ast::located::{enum_name}{cons.name} {{",
depth + 1,
)
self.gen_construction_fields(cons, sumname, depth + 1)
self.gen_construction_fields(cons, sum_name, depth + 1)
self.emit("})", depth + 1)
else:
self.emit(f"ast::located::{rustname}::{cons.name}", depth + 1)
self.emit(f"ast::located::{rust_name}::{cons.name}", depth + 1)
self.emit("} else", depth)
self.emit("{", depth)
msg = f'format!("expected some sort of {sumname}, but got {{}}",_object.repr(_vm)?)'
msg = f'format!("expected some sort of {sum_name}, but got {{}}",_object.repr(_vm)?)'
self.emit(f"return Err(_vm.new_type_error({msg}));", depth + 1)
self.emit("})", depth)
def gen_product_fromobj(self, product, prodname, structname, depth):
def gen_product_from_object(self, product, product_name, struct_name, depth):
# if product.attributes:
# self.extract_location(prodname, depth)
# self.extract_location(product_name, depth)
self.emit("Ok(", depth)
self.gen_construction(structname, product, prodname, depth + 1)
self.gen_construction(struct_name, product, product_name, depth + 1)
self.emit(")", depth)
def gen_construction_fields(self, cons, name, depth):
@ -926,19 +931,19 @@ class ChainOfVisitors:
v.emit("", 0)
def write_ast_def(mod, typeinfo, f):
StructVisitor(f, typeinfo).visit(mod)
def write_ast_def(mod, type_info, f):
StructVisitor(f, type_info).visit(mod)
def write_fold_def(mod, typeinfo, f):
FoldModuleVisitor(f, typeinfo).visit(mod)
def write_fold_def(mod, type_info, f):
FoldModuleVisitor(f, type_info).visit(mod)
def write_visitor_def(mod, typeinfo, f):
VisitorModuleVisitor(f, typeinfo).visit(mod)
def write_visitor_def(mod, type_info, f):
VisitorModuleVisitor(f, type_info).visit(mod)
def write_located_def(mod, typeinfo, f):
def write_located_def(mod, type_info, f):
f.write(
textwrap.dedent(
"""
@ -948,10 +953,10 @@ def write_located_def(mod, typeinfo, f):
"""
)
)
for info in typeinfo.values():
for info in type_info.values():
if info.empty_field:
continue
if info.has_userdata:
if info.has_user_data:
generics = "::<SourceRange>"
else:
generics = ""
@ -966,7 +971,7 @@ def write_located_def(mod, typeinfo, f):
)
def write_ast_mod(mod, typeinfo, f):
def write_ast_mod(mod, type_info, f):
f.write(
textwrap.dedent(
"""
@ -979,9 +984,9 @@ def write_ast_mod(mod, typeinfo, f):
)
c = ChainOfVisitors(
ClassDefVisitor(f, typeinfo),
TraitImplVisitor(f, typeinfo),
ExtendModuleVisitor(f, typeinfo),
class_defVisitor(f, type_info),
TraitImplVisitor(f, type_info),
ExtendModuleVisitor(f, type_info),
)
c.visit(mod)
@ -992,7 +997,7 @@ def main(
module_filename,
dump_module=False,
):
auto_gen_msg = AUTOGEN_MESSAGE.format("/".join(Path(__file__).parts[-2:]))
auto_gen_msg = AUTO_GEN_MESSAGE.format("/".join(Path(__file__).parts[-2:]))
mod = asdl.parse(input_filename)
if dump_module:
print("Parsed Module:")
@ -1000,8 +1005,8 @@ def main(
if not asdl.check(mod):
sys.exit(1)
typeinfo = {}
FindUserdataTypesVisitor(typeinfo).visit(mod)
type_info = {}
FindUserDataTypesVisitor(type_info).visit(mod)
for filename, write in [
("generic", write_ast_def),
@ -1011,11 +1016,11 @@ def main(
]:
with (ast_dir / f"{filename}.rs").open("w") as f:
f.write(auto_gen_msg)
write(mod, typeinfo, f)
write(mod, type_info, f)
with module_filename.open("w") as module_file:
module_file.write(auto_gen_msg)
write_ast_mod(mod, typeinfo, module_file)
write_ast_mod(mod, type_info, module_file)
print(f"{ast_dir}, {module_filename} regenerated.")

View file

@ -1,2 +1,3 @@
#!/bin/bash
cspell "ast/**/*.rs" "literal/**/*.rs" "core/**/*.rs" "parser/**/*.rs"
cspell ast/asdl_rs.py