Separate byteoffset ast and located ast

This commit is contained in:
Jeong YunWon 2023-05-06 21:35:43 +09:00
parent f47dfca4e3
commit a14e43e03a
21 changed files with 893 additions and 562 deletions

View file

@ -7,6 +7,8 @@ import textwrap
from argparse import ArgumentParser
from pathlib import Path
from typing import Optional, Dict
from attr import dataclass
import asdl
@ -62,38 +64,62 @@ def asdl_of(name, obj):
return "{} = {}".format(name, types)
class EmitVisitor(asdl.VisitorBase):
"""Visit that emits lines"""
def __init__(self, file):
self.file = file
self.identifiers = set()
super(EmitVisitor, self).__init__()
def emit_identifier(self, name):
name = str(name)
if name in self.identifiers:
return
self.emit("_Py_IDENTIFIER(%s);" % name, 0)
self.identifiers.add(name)
def emit(self, line, depth):
if line:
line = (" " * TABSIZE * depth) + line
self.file.write(line + "\n")
class TypeInfo:
name: str
enum_name: Optional[str]
has_userdata: Optional[bool]
has_attributes: bool
children: set
boxed: bool
product: bool
has_expr: bool = False
def __init__(self, name):
self.name = name
self.enum_name = None
self.has_userdata = None
self.has_attributes = False
self.children = set()
self.boxed = False
self.product = False
self.product_has_expr = False
def __repr__(self):
return f"<TypeInfo: {self.name}>"
@property
def rust_name(self):
return get_rust_type(self.name)
@property
def sum_name(self):
if self.enum_name is None:
return self.name
else:
return f"{self.enum_name}_{self.name}"
@property
def rust_sum_name(self):
rust_name = get_rust_type(self.name)
if self.enum_name is None:
return rust_name
else:
name = get_rust_type(self.enum_name) + rust_name
return name
@property
def rust_suffix(self):
if self.product:
if self.has_attributes:
return "Data"
else:
return ""
else:
if self.has_attributes:
return "Kind"
else:
return ""
def determine_userdata(self, typeinfo, stack):
if self.name in stack:
return None
@ -110,6 +136,41 @@ class TypeInfo:
return self.has_userdata
class TypeInfoMixin:
typeinfo: Dict[str, TypeInfo]
def has_userdata(self, typ):
return self.typeinfo[typ].has_userdata
def get_generics(self, typ, *generics):
if self.has_userdata(typ):
return [f"<{g}>" for g in generics]
else:
return ["" for g in generics]
class EmitVisitor(asdl.VisitorBase, TypeInfoMixin):
"""Visit that emits lines"""
def __init__(self, file, typeinfo):
self.file = file
self.typeinfo = typeinfo
self.identifiers = set()
super(EmitVisitor, self).__init__()
def emit_identifier(self, name):
name = str(name)
if name in self.identifiers:
return
self.emit("_Py_IDENTIFIER(%s);" % name, 0)
self.identifiers.add(name)
def emit(self, line, depth):
if line:
line = (" " * TABSIZE * depth) + line
self.file.write(line + "\n")
class FindUserdataTypesVisitor(asdl.VisitorBase):
def __init__(self, typeinfo):
self.typeinfo = typeinfo
@ -132,21 +193,29 @@ class FindUserdataTypesVisitor(asdl.VisitorBase):
info.has_userdata = False
else:
for t in sum.types:
self.typeinfo[t.name] = TypeInfo(t.name)
if not t.fields:
continue
t_info = TypeInfo(t.name)
t_info.enum_name = name
self.typeinfo[t.name] = t_info
self.add_children(t.name, t.fields)
if len(sum.types) > 1:
info.boxed = True
if sum.attributes:
# attributes means Located, which has the `custom: U` field
# attributes means located, which has the `custom: U` field
info.has_userdata = True
info.has_attributes = True
for variant in sum.types:
self.add_children(name, variant.fields)
def visitProduct(self, product, name):
info = self.typeinfo[name]
if product.attributes:
# attributes means Located, which has the `custom: U` field
# attributes means located, which has the `custom: U` field
info.has_userdata = True
info.has_attributes = True
info.has_expr = product_has_expr(product)
if len(product.fields) > 2:
info.boxed = True
info.product = True
@ -163,24 +232,17 @@ def rust_field(field_name):
return field_name
class TypeInfoEmitVisitor(EmitVisitor):
def __init__(self, file, typeinfo):
self.typeinfo = typeinfo
super().__init__(file)
def has_userdata(self, typ):
return self.typeinfo[typ].has_userdata
def get_generics(self, typ, *generics):
if self.has_userdata(typ):
return [f"<{g}>" for g in generics]
else:
return ["" for g in generics]
def product_has_expr(product):
return any(f.type != "identifier" for f in product.fields)
class StructVisitor(TypeInfoEmitVisitor):
class StructVisitor(EmitVisitor):
"""Visitor to generate typedefs for AST."""
def __init__(self, *args, **kw):
super().__init__(*args, **kw)
self.rust_type_defs = []
def visitModule(self, mod):
for dfn in mod.dfns:
self.visit(dfn)
@ -208,57 +270,56 @@ class StructVisitor(TypeInfoEmitVisitor):
def sum_with_constructors(self, sum, name, depth):
typeinfo = self.typeinfo[name]
enumname = rustname = get_rust_type(name)
suffix = typeinfo.rust_suffix
rustname = get_rust_type(name)
# all the attributes right now are for location, so if it has attrs we
# can just wrap it in Located<>
if sum.attributes:
enumname = rustname + "Kind"
# can just wrap it in Attributed<>
for t in sum.types:
if not t.fields:
continue
self.emit_attrs(depth)
self.typeinfo[t] = TypeInfo(t)
t_generics, t_generics_applied = self.get_generics(t.name, "U = ()", "U")
payload_name = f"{rustname}{t.name}"
self.emit(f"pub struct {payload_name}{t_generics} {{", depth)
for f in t.fields:
self.visit(f, typeinfo, "pub ", depth + 1, t.name)
self.emit("}", depth)
self.emit(
textwrap.dedent(
f"""
impl{t_generics_applied} From<{payload_name}{t_generics_applied}> for {enumname}{t_generics_applied} {{
fn from(payload: {payload_name}{t_generics_applied}) -> Self {{
{enumname}::{t.name}(payload)
}}
}}
"""
),
depth,
)
self.sum_subtype_struct(typeinfo, t, rustname, depth)
generics, generics_applied = self.get_generics(name, "U = ()", "U")
self.emit_attrs(depth)
self.emit(f"pub enum {enumname}{generics} {{", depth)
self.emit(f"pub enum {rustname}{suffix}{generics} {{", depth)
for t in sum.types:
if t.fields:
t_generics, t_generics_applied = self.get_generics(
t.name, "U = ()", "U"
)
(t_generics_applied,) = self.get_generics(t.name, "U")
self.emit(
f"{t.name}({rustname}{t.name}{t_generics_applied}),", depth + 1
)
else:
self.emit(f"{t.name},", depth + 1)
self.emit("}", depth)
if sum.attributes:
if typeinfo.has_attributes:
self.emit(
f"pub type {rustname}<U = ()> = Located<{enumname}{generics_applied}, U>;",
f"pub type {rustname}<U = ()> = Attributed<{rustname}{suffix}{generics_applied}, U>;",
depth,
)
self.emit("", depth)
def sum_subtype_struct(self, sum_typeinfo, t, rustname, depth):
self.emit_attrs(depth)
generics, generics_applied = self.get_generics(t.name, "U = ()", "U")
payload_name = f"{rustname}{t.name}"
self.emit(f"pub struct {payload_name}{generics} {{", depth)
for f in t.fields:
self.visit(f, sum_typeinfo, "pub ", depth + 1, t.name)
self.emit("}", depth)
self.emit(
textwrap.dedent(
f"""
impl{generics_applied} From<{payload_name}{generics_applied}> for {rustname}{sum_typeinfo.rust_suffix}{generics_applied} {{
fn from(payload: {payload_name}{generics_applied}) -> Self {{
{rustname}{sum_typeinfo.rust_suffix}::{t.name}(payload)
}}
}}
"""
),
depth,
)
def visitConstructor(self, cons, parent, depth):
if cons.fields:
self.emit(f"{cons.name} {{", depth)
@ -300,7 +361,7 @@ class StructVisitor(TypeInfoEmitVisitor):
if product.attributes:
dataname = rustname + "Data"
self.emit_attrs(depth)
has_expr = any(f.type != "identifier" for f in product.fields)
has_expr = product_has_expr(product)
if has_expr:
datadef = f"{dataname}{generics}"
else:
@ -314,20 +375,24 @@ class StructVisitor(TypeInfoEmitVisitor):
if not has_expr:
generics_applied = ""
self.emit(
f"pub type {rustname}<U = ()> = Located<{dataname}{generics_applied}, U>;",
f"pub type {rustname}<U = ()> = Attributed<{dataname}{generics_applied}, U>;",
depth,
)
self.emit("", depth)
class FoldTraitDefVisitor(TypeInfoEmitVisitor):
class FoldTraitDefVisitor(EmitVisitor):
def visitModule(self, mod, depth):
self.emit("pub trait Fold<U> {", depth)
self.emit("type TargetU;", depth + 1)
self.emit("type Error;", depth + 1)
self.emit(
"fn map_user(&mut self, user: U) -> Result<Self::TargetU, Self::Error>;",
depth + 2,
depth + 1,
)
self.emit(
"fn map_located<T>(&mut self, located: Attributed<T, U>) -> Result<Attributed<T, Self::TargetU>, Self::Error> { let custom = self.map_user(located.custom)?; Ok(Attributed { range: located.range, custom, node: located.node }) }",
depth + 1,
)
for dfn in mod.dfns:
self.visit(dfn, depth + 2)
@ -345,14 +410,14 @@ class FoldTraitDefVisitor(TypeInfoEmitVisitor):
self.emit("}", depth)
class FoldImplVisitor(TypeInfoEmitVisitor):
class FoldImplVisitor(EmitVisitor):
def visitModule(self, mod, depth):
self.emit(
"fn fold_located<U, F: Fold<U> + ?Sized, T, MT>(folder: &mut F, node: Located<T, U>, f: impl FnOnce(&mut F, T) -> Result<MT, F::Error>) -> Result<Located<MT, F::TargetU>, F::Error> {",
"fn fold_located<U, F: Fold<U> + ?Sized, T, MT>(folder: &mut F, node: Attributed<T, U>, f: impl FnOnce(&mut F, T) -> Result<MT, F::Error>) -> Result<Attributed<MT, F::TargetU>, F::Error> {",
depth,
)
self.emit(
"Ok(Located { custom: folder.map_user(node.custom)?, range: node.range, node: f(folder, node.node)? })",
"let node = folder.map_located(node)?; Ok(Attributed { custom: node.custom, range: node.range, node: f(folder, node.node)? })",
depth + 1,
)
self.emit("}", depth)
@ -363,11 +428,11 @@ class FoldImplVisitor(TypeInfoEmitVisitor):
self.visit(type.value, type.name, depth)
def visitSum(self, sum, name, depth):
typeinfo = self.typeinfo[name]
apply_t, apply_u, apply_target_u = self.get_generics(
name, "T", "U", "F::TargetU"
)
enumname = get_rust_type(name)
is_located = bool(sum.attributes)
self.emit(f"impl<T, U> Foldable<T, U> for {enumname}{apply_t} {{", depth)
self.emit(f"type Mapped = {enumname}{apply_u};", depth + 1)
@ -383,15 +448,13 @@ class FoldImplVisitor(TypeInfoEmitVisitor):
f"pub fn fold_{name}<U, F: Fold<U> + ?Sized>(#[allow(unused)] folder: &mut F, node: {enumname}{apply_u}) -> Result<{enumname}{apply_target_u}, F::Error> {{",
depth,
)
if is_located:
if typeinfo.has_attributes:
self.emit("fold_located(folder, node, |folder, node| {", depth)
rustname = enumname + "Kind"
else:
rustname = enumname
self.emit("match node {", depth + 1)
for cons in sum.types:
fields_pattern = self.make_pattern(
enumname, rustname, cons.name, cons.fields
enumname, typeinfo.rust_suffix, cons.name, cons.fields
)
self.emit(
f"{fields_pattern[0]} {{ {fields_pattern[1]} }} {fields_pattern[2]} => {{",
@ -402,7 +465,7 @@ class FoldImplVisitor(TypeInfoEmitVisitor):
)
self.emit("}", depth + 2)
self.emit("}", depth + 1)
if is_located:
if typeinfo.has_attributes:
self.emit("})", depth)
self.emit("}", depth)
@ -411,7 +474,7 @@ class FoldImplVisitor(TypeInfoEmitVisitor):
name, "T", "U", "F::TargetU"
)
structname = get_rust_type(name)
is_located = bool(product.attributes)
has_attributes = bool(product.attributes)
self.emit(f"impl<T, U> Foldable<T, U> for {structname}{apply_t} {{", depth)
self.emit(f"type Mapped = {structname}{apply_u};", depth + 1)
@ -427,7 +490,7 @@ class FoldImplVisitor(TypeInfoEmitVisitor):
f"pub fn fold_{name}<U, F: Fold<U> + ?Sized>(#[allow(unused)] folder: &mut F, node: {structname}{apply_u}) -> Result<{structname}{apply_target_u}, F::Error> {{",
depth,
)
if is_located:
if has_attributes:
self.emit("fold_located(folder, node, |folder, node| {", depth)
rustname = structname + "Data"
else:
@ -435,16 +498,16 @@ class FoldImplVisitor(TypeInfoEmitVisitor):
fields_pattern = self.make_pattern(rustname, structname, None, product.fields)
self.emit(f"let {rustname} {{ {fields_pattern[1]} }} = node;", depth + 1)
self.gen_construction(rustname, product.fields, "", depth + 1)
if is_located:
if has_attributes:
self.emit("})", depth)
self.emit("}", depth)
def make_pattern(self, rustname, pyname, fieldname, fields):
def make_pattern(self, rustname, suffix, fieldname, fields):
if fields:
header = f"{pyname}::{fieldname}({rustname}{fieldname}"
header = f"{rustname}{suffix}::{fieldname}({rustname}{fieldname}"
footer = ")"
else:
header = f"{pyname}::{fieldname}"
header = f"{rustname}{suffix}::{fieldname}"
footer = ""
body = ",".join(rust_field(f.name) for f in fields)
@ -458,7 +521,7 @@ class FoldImplVisitor(TypeInfoEmitVisitor):
self.emit(f"}}{footer})", depth)
class FoldModuleVisitor(TypeInfoEmitVisitor):
class FoldModuleVisitor(EmitVisitor):
def visitModule(self, mod):
depth = 0
self.emit('#[cfg(feature = "fold")]', depth)
@ -576,10 +639,10 @@ class TraitImplVisitor(EmitVisitor):
if sum.attributes:
rustname = enumname + "Kind"
self.emit(f"impl NamedNode for ast::{rustname} {{", depth)
self.emit(f"impl NamedNode for ast::located::{rustname} {{", depth)
self.emit(f"const NAME: &'static str = {json.dumps(name)};", depth + 1)
self.emit("}", depth)
self.emit(f"impl Node for ast::{rustname} {{", depth)
self.emit(f"impl Node for ast::located::{rustname} {{", depth)
self.emit(
"fn ast_to_object(self, _vm: &VirtualMachine) -> PyObjectRef {", depth + 1
)
@ -597,10 +660,12 @@ class TraitImplVisitor(EmitVisitor):
self.emit("}", depth)
def constructor_to_object(self, cons, enumname, rustname, depth):
self.emit(f"ast::{rustname}::{cons.name}", depth)
self.emit(f"ast::located::{rustname}::{cons.name}", depth)
if cons.fields:
fields_pattern = self.make_pattern(cons.fields)
self.emit(f"( ast::{enumname}{cons.name} {{ {fields_pattern} }} )", depth)
self.emit(
f"( ast::located::{enumname}{cons.name} {{ {fields_pattern} }} )", depth
)
self.emit(" => {", depth)
self.make_node(cons.name, cons.fields, depth + 1)
self.emit("}", depth)
@ -610,15 +675,17 @@ class TraitImplVisitor(EmitVisitor):
if product.attributes:
structname += "Data"
self.emit(f"impl NamedNode for ast::{structname} {{", depth)
self.emit(f"impl NamedNode for ast::located::{structname} {{", depth)
self.emit(f"const NAME: &'static str = {json.dumps(name)};", depth + 1)
self.emit("}", depth)
self.emit(f"impl Node for ast::{structname} {{", depth)
self.emit(f"impl Node for ast::located::{structname} {{", depth)
self.emit(
"fn ast_to_object(self, _vm: &VirtualMachine) -> PyObjectRef {", depth + 1
)
fields_pattern = self.make_pattern(product.fields)
self.emit(f"let ast::{structname} {{ {fields_pattern} }} = self;", depth + 2)
self.emit(
f"let ast::located::{structname} {{ {fields_pattern} }} = self;", depth + 2
)
self.make_node(name, product.fields, depth + 2)
self.emit("}", depth + 1)
self.emit(
@ -656,11 +723,14 @@ class TraitImplVisitor(EmitVisitor):
for cons in sum.types:
self.emit(f"if _cls.is(Node{cons.name}::static_type()) {{", depth)
if cons.fields:
self.emit(f"ast::{rustname}::{cons.name} (ast::{enumname}{cons.name} {{", depth + 1)
self.emit(
f"ast::located::{rustname}::{cons.name} (ast::located::{enumname}{cons.name} {{",
depth + 1,
)
self.gen_construction_fields(cons, sumname, depth + 1)
self.emit("})", depth + 1)
else:
self.emit(f"ast::{rustname}::{cons.name}", depth + 1)
self.emit(f"ast::located::{rustname}::{cons.name}", depth + 1)
self.emit("} else", depth)
self.emit("{", depth)
@ -684,14 +754,14 @@ class TraitImplVisitor(EmitVisitor):
)
def gen_construction(self, cons_path, cons, name, depth):
self.emit(f"ast::{cons_path} {{", depth)
self.emit(f"ast::located::{cons_path} {{", depth)
self.gen_construction_fields(cons, name, depth + 1)
self.emit("}", depth)
def extract_location(self, typename, depth):
row = self.decode_field(asdl.Field("int", "lineno"), typename)
column = self.decode_field(asdl.Field("int", "col_offset"), typename)
self.emit(f"let _location = ast::Location::new({row}, {column});", depth)
self.emit(f"let _location = Location::new({row}, {column});", depth)
def decode_field(self, field, typename):
name = json.dumps(field.name)
@ -711,84 +781,18 @@ class ChainOfVisitors:
v.emit("", 0)
def write_ast_def(mod, typeinfo, f):
def write_generic_def(mod, typeinfo, f):
f.write(
textwrap.dedent(
"""
#![allow(clippy::derive_partial_eq_without_eq)]
pub use crate::constant::*;
pub use rustpython_compiler_core::text_size::{TextSize, TextRange};
pub use crate::{Attributed, constant::*};
use rustpython_compiler_core::{text_size::{TextSize, TextRange}};
type Ident = String;
\n
"""
)
)
StructVisitor(f, typeinfo).emit_attrs(0)
f.write(
textwrap.dedent(
"""
pub struct Located<T, U = ()> {
pub range: TextRange,
pub custom: U,
pub node: T,
}
impl<T> Located<T> {
pub fn new(start: TextSize, end: TextSize, node: T) -> Self {
Self { range: TextRange::new(start, end), custom: (), node }
}
/// Creates a new node that spans the position specified by `range`.
pub fn with_range(node: T, range: TextRange) -> Self {
Self {
range,
custom: (),
node,
}
}
/// Returns the absolute start position of the node from the beginning of the document.
#[inline]
pub const fn start(&self) -> TextSize {
self.range.start()
}
/// Returns the node
#[inline]
pub fn node(&self) -> &T {
&self.node
}
/// Consumes self and returns the node.
#[inline]
pub fn into_node(self) -> T {
self.node
}
/// Returns the `range` of the node. The range offsets are absolute to the start of the document.
#[inline]
pub const fn range(&self) -> TextRange {
self.range
}
/// Returns the absolute position at which the node ends in the source document.
#[inline]
pub const fn end(&self) -> TextSize {
self.range.end()
}
}
impl<T, U> std::ops::Deref for Located<T, U> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.node
}
}
\n
""".lstrip()
"""
)
)
@ -796,24 +800,59 @@ def write_ast_def(mod, typeinfo, f):
c.visit(mod)
def write_ast_mod(mod, f):
def write_located_def(typeinfo, f):
f.write(
textwrap.dedent(
"""
#![allow(clippy::all)]
use rustpython_compiler_core::LocationRange;
use super::*;
use crate::common::ascii;
pub type Located<T> = super::generic::Attributed<T, LocationRange>;
"""
)
)
for info in typeinfo.values():
if info.has_userdata:
generics = "::<LocationRange>"
else:
generics = ""
f.write(
f"pub type {info.rust_sum_name} = super::generic::{info.rust_sum_name}{generics};\n"
)
if info.rust_suffix:
if info.rust_suffix == "Data" and not info.has_expr:
generics = ""
f.write(
f"pub type {info.rust_sum_name}{info.rust_suffix} = super::generic::{info.rust_sum_name}{info.rust_suffix}{generics};\n"
)
"""
def write_ast_mod(mod, typeinfo, f):
f.write(
textwrap.dedent(
"""
#![allow(clippy::all)]
use super::*;
use crate::common::ascii;
"""
)
)
c = ChainOfVisitors(ClassDefVisitor(f), TraitImplVisitor(f), ExtendModuleVisitor(f))
c = ChainOfVisitors(
ClassDefVisitor(f, typeinfo),
TraitImplVisitor(f, typeinfo),
ExtendModuleVisitor(f, typeinfo),
)
c.visit(mod)
def main(input_filename, ast_mod_filename, ast_def_filename, dump_module=False):
def main(
input_filename,
generic_filename,
located_filename,
module_filename,
dump_module=False,
):
auto_gen_msg = AUTOGEN_MESSAGE.format("/".join(Path(__file__).parts[-2:]))
mod = asdl.parse(input_filename)
if dump_module:
@ -825,22 +864,34 @@ def main(input_filename, ast_mod_filename, ast_def_filename, dump_module=False):
typeinfo = {}
FindUserdataTypesVisitor(typeinfo).visit(mod)
with ast_def_filename.open("w") as def_file, ast_mod_filename.open("w") as mod_file:
def_file.write(auto_gen_msg)
write_ast_def(mod, typeinfo, def_file)
with generic_filename.open("w") as generic_file, located_filename.open(
"w"
) as located_file:
generic_file.write(auto_gen_msg)
write_generic_def(mod, typeinfo, generic_file)
located_file.write(auto_gen_msg)
write_located_def(typeinfo, located_file)
mod_file.write(auto_gen_msg)
write_ast_mod(mod, mod_file)
with module_filename.open("w") as module_file:
module_file.write(auto_gen_msg)
write_ast_mod(mod, typeinfo, module_file)
print(f"{ast_def_filename}, {ast_mod_filename} regenerated.")
print(f"{generic_filename}, {located_filename}, {module_filename} regenerated.")
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("input_file", type=Path)
parser.add_argument("-M", "--mod-file", type=Path, required=True)
parser.add_argument("-D", "--def-file", type=Path, required=True)
parser.add_argument("-G", "--generic-file", type=Path, required=True)
parser.add_argument("-L", "--located-file", type=Path, required=True)
parser.add_argument("-M", "--module-file", type=Path, required=True)
parser.add_argument("-d", "--dump-module", action="store_true")
args = parser.parse_args()
main(args.input_file, args.mod_file, args.def_file, args.dump_module)
main(
args.input_file,
args.generic_file,
args.located_file,
args.module_file,
args.dump_module,
)