mirror of
https://github.com/RustPython/Parser.git
synced 2025-08-27 13:54:55 +00:00
New Arguments and Arg/ArgWithDefault AST representation (#59)
This commit is contained in:
parent
3fbf4f6804
commit
fdec727f80
51 changed files with 22648 additions and 21711 deletions
376
ast/asdl_rs.py
376
ast/asdl_rs.py
|
@ -45,6 +45,74 @@ RUST_KEYWORDS = {
|
|||
"type",
|
||||
}
|
||||
|
||||
attributes = [
|
||||
asdl.Field("int", "lineno"),
|
||||
asdl.Field("int", "col_offset"),
|
||||
asdl.Field("int", "end_lineno"),
|
||||
asdl.Field("int", "end_col_offset"),
|
||||
]
|
||||
|
||||
ORIGINAL_NODE_WARNING = "NOTE: This type is different from original Python AST."
|
||||
|
||||
arg_with_default = asdl.Type(
|
||||
"arg_with_default",
|
||||
asdl.Product(
|
||||
[
|
||||
asdl.Field("arg", "def"),
|
||||
asdl.Field(
|
||||
"expr", "default", opt=True
|
||||
), # order is important for cost-free borrow!
|
||||
],
|
||||
),
|
||||
)
|
||||
arg_with_default.doc = f"""
|
||||
An alternative type of AST `arg`. This is used for each function argument that might have a default value.
|
||||
Used by `Arguments` original type.
|
||||
|
||||
{ORIGINAL_NODE_WARNING}
|
||||
""".strip()
|
||||
|
||||
alt_arguments = asdl.Type(
|
||||
"alt:arguments",
|
||||
asdl.Product(
|
||||
[
|
||||
asdl.Field("arg_with_default", "posonlyargs", seq=True),
|
||||
asdl.Field("arg_with_default", "args", seq=True),
|
||||
asdl.Field("arg", "vararg", opt=True),
|
||||
asdl.Field("arg_with_default", "kwonlyargs", seq=True),
|
||||
asdl.Field("arg", "kwarg", opt=True),
|
||||
]
|
||||
),
|
||||
)
|
||||
alt_arguments.doc = f"""
|
||||
An alternative type of AST `arguments`. This is parser-friendly and human-friendly definition of function arguments.
|
||||
This form also has advantage to implement pre-order traverse.
|
||||
`defaults` and `kw_defaults` fields are removed and the default values are placed under each `arg_with_default` typed argument.
|
||||
`vararg` and `kwarg` are still typed as `arg` because they never can have a default value.
|
||||
|
||||
The matching Python style AST type is [PythonArguments]. While [PythonArguments] has ordered `kwonlyargs` fields by
|
||||
default existence, [Arguments] has location-ordered kwonlyargs fields.
|
||||
|
||||
{ORIGINAL_NODE_WARNING}
|
||||
""".strip()
|
||||
|
||||
# Must be used only for rust types, not python types
|
||||
CUSTOM_TYPES = [
|
||||
alt_arguments,
|
||||
arg_with_default,
|
||||
]
|
||||
|
||||
CUSTOM_REPLACEMENTS = {
|
||||
"arguments": alt_arguments,
|
||||
}
|
||||
CUSTOM_ATTACHMENTS = [
|
||||
arg_with_default,
|
||||
]
|
||||
|
||||
|
||||
def maybe_custom(type):
|
||||
return CUSTOM_REPLACEMENTS.get(type.name, type)
|
||||
|
||||
|
||||
def rust_field_name(name):
|
||||
name = rust_type_name(name)
|
||||
|
@ -137,6 +205,20 @@ class TypeInfo:
|
|||
f.type != "identifier" for f in self.type.value.fields
|
||||
)
|
||||
|
||||
@property
|
||||
def is_custom(self):
|
||||
return self.type.name in [t.name for t in CUSTOM_TYPES]
|
||||
|
||||
@property
|
||||
def is_custom_replaced(self):
|
||||
return self.type.name in CUSTOM_REPLACEMENTS
|
||||
|
||||
@property
|
||||
def custom(self):
|
||||
if self.type.name in CUSTOM_REPLACEMENTS:
|
||||
return CUSTOM_REPLACEMENTS[self.type.name]
|
||||
return self.type
|
||||
|
||||
def no_cfg(self, typeinfo):
|
||||
if self.is_product:
|
||||
return self.has_attributes
|
||||
|
@ -150,20 +232,26 @@ class TypeInfo:
|
|||
return rust_type_name(self.name)
|
||||
|
||||
@property
|
||||
def sum_name(self):
|
||||
def full_field_name(self):
|
||||
name = self.name
|
||||
if name.startswith("alt:"):
|
||||
name = name[4:]
|
||||
if self.enum_name is None:
|
||||
return self.name
|
||||
return name
|
||||
else:
|
||||
return f"{self.enum_name}_{self.name}"
|
||||
return f"{self.enum_name}_{rust_field_name(name)}"
|
||||
|
||||
@property
|
||||
def rust_sum_name(self):
|
||||
rust_name = rust_type_name(self.name)
|
||||
if self.enum_name is None:
|
||||
return rust_name
|
||||
else:
|
||||
name = rust_type_name(self.enum_name) + rust_name
|
||||
return name
|
||||
def full_type_name(self):
|
||||
name = self.name
|
||||
if name.startswith("alt:"):
|
||||
name = name[4:]
|
||||
rust_name = rust_type_name(name)
|
||||
if self.enum_name is not None:
|
||||
rust_name = rust_type_name(self.enum_name) + rust_name
|
||||
if self.is_custom_replaced:
|
||||
rust_name = "Python" + rust_name
|
||||
return rust_name
|
||||
|
||||
def determine_user_data(self, type_info, stack):
|
||||
if self.name in stack:
|
||||
|
@ -184,6 +272,10 @@ class TypeInfo:
|
|||
class TypeInfoMixin:
|
||||
type_info: Dict[str, TypeInfo]
|
||||
|
||||
def customized_type_info(self, type_name):
|
||||
info = self.type_info[type_name]
|
||||
return self.type_info[info.custom.name]
|
||||
|
||||
def has_user_data(self, typ):
|
||||
return self.type_info[typ].has_user_data
|
||||
|
||||
|
@ -223,18 +315,19 @@ class FindUserDataTypesVisitor(asdl.VisitorBase):
|
|||
super().__init__()
|
||||
|
||||
def visitModule(self, mod):
|
||||
for dfn in mod.dfns:
|
||||
for dfn in mod.dfns + CUSTOM_TYPES:
|
||||
self.visit(dfn)
|
||||
stack = set()
|
||||
for info in self.type_info.values():
|
||||
info.determine_user_data(self.type_info, stack)
|
||||
|
||||
def visitType(self, type):
|
||||
self.type_info[type.name] = TypeInfo(type)
|
||||
self.visit(type.value, type)
|
||||
key = type.name
|
||||
info = self.type_info[key] = TypeInfo(type)
|
||||
self.visit(type.value, info)
|
||||
|
||||
def visitSum(self, sum, type):
|
||||
info = self.type_info[type.name]
|
||||
def visitSum(self, sum, info):
|
||||
type = info.type
|
||||
info.is_simple = is_simple(sum)
|
||||
for cons in sum.types:
|
||||
self.visit(cons, type, info.is_simple)
|
||||
|
@ -261,8 +354,8 @@ class FindUserDataTypesVisitor(asdl.VisitorBase):
|
|||
info.enum_name = type.name
|
||||
info.is_simple = simple
|
||||
|
||||
def visitProduct(self, product, type):
|
||||
info = self.type_info[type.name]
|
||||
def visitProduct(self, product, info):
|
||||
type = info.type
|
||||
if product.attributes:
|
||||
# attributes means located, which has the `range: R` field
|
||||
info.has_user_data = True
|
||||
|
@ -308,7 +401,9 @@ class StructVisitor(EmitVisitor):
|
|||
0,
|
||||
)
|
||||
for dfn in mod.dfns:
|
||||
rust_name = rust_type_name(dfn.name)
|
||||
info = self.customized_type_info(dfn.name)
|
||||
dfn = info.custom
|
||||
rust_name = info.full_type_name
|
||||
generics = "" if self.type_info[dfn.name].is_simple else "<R>"
|
||||
if dfn.name == "mod":
|
||||
# This is exceptional rule to other enums.
|
||||
|
@ -329,7 +424,8 @@ class StructVisitor(EmitVisitor):
|
|||
0,
|
||||
)
|
||||
for dfn in mod.dfns:
|
||||
rust_name = rust_type_name(dfn.name)
|
||||
info = self.customized_type_info(dfn.name)
|
||||
rust_name = info.full_type_name
|
||||
generics = "" if self.type_info[dfn.name].is_simple else "<R>"
|
||||
self.emit(
|
||||
f"""
|
||||
|
@ -342,10 +438,13 @@ class StructVisitor(EmitVisitor):
|
|||
0,
|
||||
)
|
||||
|
||||
for dfn in mod.dfns:
|
||||
for dfn in mod.dfns + CUSTOM_TYPES:
|
||||
self.visit(dfn)
|
||||
|
||||
def visitType(self, type, depth=0):
|
||||
if hasattr(type, "doc"):
|
||||
doc = "/// " + type.doc.replace("\n", "\n/// ") + "\n"
|
||||
self.emit(doc, depth)
|
||||
self.visit(type.value, type, depth)
|
||||
|
||||
def visitSum(self, sum, type, depth):
|
||||
|
@ -492,8 +591,12 @@ class StructVisitor(EmitVisitor):
|
|||
self.emit(f"{cons.name},", depth)
|
||||
|
||||
def visitField(self, field, parent, vis, depth, constructor=None):
|
||||
typ = rust_type_name(field.type)
|
||||
field_type = self.type_info.get(field.type)
|
||||
try:
|
||||
field_type = self.customized_type_info(field.type)
|
||||
typ = field_type.full_type_name
|
||||
except KeyError:
|
||||
field_type = None
|
||||
typ = rust_type_name(field.type)
|
||||
if field_type and not field_type.is_simple:
|
||||
typ = f"{typ}<R>"
|
||||
# don't box if we're doing Vec<T>, but do box if we're doing Vec<Option<Box<T>>>
|
||||
|
@ -520,9 +623,8 @@ class StructVisitor(EmitVisitor):
|
|||
|
||||
def visitProduct(self, product, type, depth):
|
||||
type_info = self.type_info[type.name]
|
||||
product_name = rust_type_name(type.name)
|
||||
product_name = type_info.full_type_name
|
||||
self.emit_attrs(depth)
|
||||
|
||||
self.emit(f"pub struct {product_name}<R = TextRange> {{", depth)
|
||||
self.emit_range(product.attributes, depth + 1)
|
||||
for f in product.fields:
|
||||
|
@ -584,19 +686,20 @@ class FoldTraitDefVisitor(EmitVisitor):
|
|||
}""",
|
||||
depth + 1,
|
||||
)
|
||||
for dfn in mod.dfns:
|
||||
for dfn in mod.dfns + [arg_with_default]:
|
||||
dfn = maybe_custom(dfn)
|
||||
self.visit(dfn, depth + 2)
|
||||
self.emit("}", depth)
|
||||
|
||||
def visitType(self, type, depth):
|
||||
name = type.name
|
||||
apply_u, apply_target_u = self.apply_generics(name, "U", "Self::TargetU")
|
||||
enum_name = rust_type_name(name)
|
||||
info = self.type_info[type.name]
|
||||
apply_u, apply_target_u = self.apply_generics(info.name, "U", "Self::TargetU")
|
||||
enum_name = info.full_type_name
|
||||
self.emit(
|
||||
f"fn fold_{name}(&mut self, node: {enum_name}{apply_u}) -> Result<{enum_name}{apply_target_u}, Self::Error> {{",
|
||||
f"fn fold_{info.full_field_name}(&mut self, node: {enum_name}{apply_u}) -> Result<{enum_name}{apply_target_u}, Self::Error> {{",
|
||||
depth,
|
||||
)
|
||||
self.emit(f"fold_{name}(self, node)", depth + 1)
|
||||
self.emit(f"fold_{info.full_field_name}(self, node)", depth + 1)
|
||||
self.emit("}", depth)
|
||||
|
||||
if isinstance(type.value, asdl.Sum) and not is_simple(type.value):
|
||||
|
@ -604,9 +707,10 @@ class FoldTraitDefVisitor(EmitVisitor):
|
|||
self.visit(cons, type, depth)
|
||||
|
||||
def visitConstructor(self, cons, type, depth):
|
||||
info = self.type_info[type.name]
|
||||
apply_u, apply_target_u = self.apply_generics(type.name, "U", "Self::TargetU")
|
||||
enum_name = rust_type_name(type.name)
|
||||
func_name = f"fold_{type.name}_{rust_field_name(cons.name)}"
|
||||
func_name = f"fold_{info.full_field_name}_{rust_field_name(cons.name)}"
|
||||
self.emit(
|
||||
f"fn {func_name}(&mut self, node: {enum_name}{cons.name}{apply_u}) -> Result<{enum_name}{cons.name}{apply_target_u}, Self::Error> {{",
|
||||
depth,
|
||||
|
@ -617,7 +721,8 @@ class FoldTraitDefVisitor(EmitVisitor):
|
|||
|
||||
class FoldImplVisitor(EmitVisitor):
|
||||
def visitModule(self, mod, depth):
|
||||
for dfn in mod.dfns:
|
||||
for dfn in mod.dfns + [arg_with_default]:
|
||||
dfn = maybe_custom(dfn)
|
||||
self.visit(dfn, depth)
|
||||
|
||||
def visitType(self, type, depth=0):
|
||||
|
@ -668,7 +773,8 @@ class FoldImplVisitor(EmitVisitor):
|
|||
apply_t, apply_u, apply_target_u = self.apply_generics(
|
||||
type.name, "T", "U", "F::TargetU"
|
||||
)
|
||||
enum_name = rust_type_name(type.name)
|
||||
info = self.type_info[type.name]
|
||||
enum_name = info.full_type_name
|
||||
|
||||
cons_type_name = f"{enum_name}{cons.name}"
|
||||
|
||||
|
@ -679,21 +785,20 @@ class FoldImplVisitor(EmitVisitor):
|
|||
depth + 1,
|
||||
)
|
||||
self.emit(
|
||||
f"folder.fold_{type.name}_{rust_field_name(cons.name)}(self)", depth + 2
|
||||
f"folder.fold_{info.full_field_name}_{rust_field_name(cons.name)}(self)",
|
||||
depth + 2,
|
||||
)
|
||||
self.emit("}", depth + 1)
|
||||
self.emit("}", depth)
|
||||
|
||||
self.emit(
|
||||
f"pub fn fold_{type.name}_{rust_field_name(cons.name)}<U, F: Fold<U> + ?Sized>(#[allow(unused)] folder: &mut F, node: {cons_type_name}{apply_u}) -> Result<{enum_name}{cons.name}{apply_target_u}, F::Error> {{",
|
||||
f"pub fn fold_{info.full_field_name}_{rust_field_name(cons.name)}<U, F: Fold<U> + ?Sized>(#[allow(unused)] folder: &mut F, node: {cons_type_name}{apply_u}) -> Result<{enum_name}{cons.name}{apply_target_u}, F::Error> {{",
|
||||
depth,
|
||||
)
|
||||
|
||||
type_info = self.type_info[type.name]
|
||||
|
||||
fields_pattern = self.make_pattern(cons.fields)
|
||||
|
||||
map_user_suffix = "" if type_info.has_attributes else "_cfg"
|
||||
map_user_suffix = "" if info.has_attributes else "_cfg"
|
||||
self.emit(
|
||||
f"""
|
||||
let {cons_type_name} {{ {fields_pattern} }} = node;
|
||||
|
@ -710,11 +815,12 @@ class FoldImplVisitor(EmitVisitor):
|
|||
self.emit("}", depth + 2)
|
||||
|
||||
def visitProduct(self, product, type, depth):
|
||||
info = self.type_info[type.name]
|
||||
name = type.name
|
||||
apply_t, apply_u, apply_target_u = self.apply_generics(
|
||||
name, "T", "U", "F::TargetU"
|
||||
)
|
||||
struct_name = rust_type_name(name)
|
||||
struct_name = info.full_type_name
|
||||
has_attributes = bool(product.attributes)
|
||||
|
||||
self.emit(f"impl<T, U> Foldable<T, U> for {struct_name}{apply_t} {{", depth)
|
||||
|
@ -723,12 +829,12 @@ class FoldImplVisitor(EmitVisitor):
|
|||
"fn fold<F: Fold<T, TargetU = U> + ?Sized>(self, folder: &mut F) -> Result<Self::Mapped, F::Error> {",
|
||||
depth + 1,
|
||||
)
|
||||
self.emit(f"folder.fold_{name}(self)", depth + 2)
|
||||
self.emit(f"folder.fold_{info.full_field_name}(self)", depth + 2)
|
||||
self.emit("}", depth + 1)
|
||||
self.emit("}", depth)
|
||||
|
||||
self.emit(
|
||||
f"pub fn fold_{name}<U, F: Fold<U> + ?Sized>(#[allow(unused)] folder: &mut F, node: {struct_name}{apply_u}) -> Result<{struct_name}{apply_target_u}, F::Error> {{",
|
||||
f"pub fn fold_{info.full_field_name}<U, F: Fold<U> + ?Sized>(#[allow(unused)] folder: &mut F, node: {struct_name}{apply_u}) -> Result<{struct_name}{apply_target_u}, F::Error> {{",
|
||||
depth,
|
||||
)
|
||||
|
||||
|
@ -777,7 +883,7 @@ class FoldModuleVisitor(EmitVisitor):
|
|||
FoldImplVisitor(self.file, self.type_info).visit(mod, depth)
|
||||
|
||||
|
||||
class VisitorTraitDefVisitor(StructVisitor):
|
||||
class VisitorModuleVisitor(StructVisitor):
|
||||
def full_name(self, name):
|
||||
type_info = self.type_info[name]
|
||||
if type_info.enum_name:
|
||||
|
@ -792,10 +898,12 @@ class VisitorTraitDefVisitor(StructVisitor):
|
|||
else:
|
||||
return rust_type_name(name)
|
||||
|
||||
def visitModule(self, mod, depth):
|
||||
def visitModule(self, mod, depth=0):
|
||||
self.emit("#[allow(unused_variables)]", depth)
|
||||
self.emit("pub trait Visitor<R=crate::text_size::TextRange> {", depth)
|
||||
|
||||
for dfn in mod.dfns:
|
||||
dfn = self.customized_type_info(dfn.name).type
|
||||
self.visit(dfn, depth + 1)
|
||||
self.emit("}", depth)
|
||||
|
||||
|
@ -810,26 +918,28 @@ class VisitorTraitDefVisitor(StructVisitor):
|
|||
|
||||
def emit_visitor(self, nodename, depth, has_node=True):
|
||||
type_info = self.type_info[nodename]
|
||||
node_type = type_info.rust_sum_name
|
||||
node_type = type_info.full_type_name
|
||||
(generic,) = self.apply_generics(nodename, "R")
|
||||
self.emit(
|
||||
f"fn visit_{type_info.sum_name}(&mut self, node: {node_type}{generic}) {{",
|
||||
f"fn visit_{type_info.full_field_name}(&mut self, node: {node_type}{generic}) {{",
|
||||
depth,
|
||||
)
|
||||
if has_node:
|
||||
self.emit(f"self.generic_visit_{type_info.sum_name}(node)", depth + 1)
|
||||
self.emit(
|
||||
f"self.generic_visit_{type_info.full_field_name}(node)", depth + 1
|
||||
)
|
||||
|
||||
self.emit("}", depth)
|
||||
|
||||
def emit_generic_visitor_signature(self, nodename, depth, has_node=True):
|
||||
type_info = self.type_info[nodename]
|
||||
if has_node:
|
||||
node_type = type_info.rust_sum_name
|
||||
node_type = type_info.full_type_name
|
||||
else:
|
||||
node_type = "()"
|
||||
(generic,) = self.apply_generics(nodename, "R")
|
||||
self.emit(
|
||||
f"fn generic_visit_{type_info.sum_name}(&mut self, node: {node_type}{generic}) {{",
|
||||
f"fn generic_visit_{type_info.full_field_name}(&mut self, node: {node_type}{generic}) {{",
|
||||
depth,
|
||||
)
|
||||
|
||||
|
@ -844,7 +954,9 @@ class VisitorTraitDefVisitor(StructVisitor):
|
|||
def visit_match_for_type(self, nodename, rust_name, type_, depth):
|
||||
self.emit(f"{rust_name}::{type_.name}", depth)
|
||||
self.emit("(data)", depth)
|
||||
self.emit(f"=> self.visit_{nodename}_{type_.name}(data),", depth)
|
||||
self.emit(
|
||||
f"=> self.visit_{nodename}_{rust_field_name(type_.name)}(data),", depth
|
||||
)
|
||||
|
||||
def visit_sum_type(self, name, type_, depth):
|
||||
self.emit_visitor(type_.name, depth, has_node=type_.fields)
|
||||
|
@ -852,28 +964,32 @@ class VisitorTraitDefVisitor(StructVisitor):
|
|||
return
|
||||
|
||||
self.emit_generic_visitor_signature(type_.name, depth, has_node=True)
|
||||
for f in type_.fields:
|
||||
fieldname = rust_field(f.name)
|
||||
field_type = self.type_info.get(f.type)
|
||||
for field in type_.fields:
|
||||
if field.type in CUSTOM_REPLACEMENTS:
|
||||
type_name = CUSTOM_REPLACEMENTS[field.type].name
|
||||
else:
|
||||
type_name = field.type
|
||||
field_name = rust_field(field.name)
|
||||
field_type = self.type_info.get(type_name)
|
||||
if not (field_type and field_type.has_user_data):
|
||||
continue
|
||||
|
||||
if f.opt:
|
||||
self.emit(f"if let Some(value) = node.{fieldname} {{", depth + 1)
|
||||
elif f.seq:
|
||||
iterable = f"node.{fieldname}"
|
||||
if type_.name == "Dict" and f.name == "keys":
|
||||
if field.opt:
|
||||
self.emit(f"if let Some(value) = node.{field_name} {{", depth + 1)
|
||||
elif field.seq:
|
||||
iterable = f"node.{field_name}"
|
||||
if type_.name == "Dict" and field.name == "keys":
|
||||
iterable = f"{iterable}.into_iter().flatten()"
|
||||
self.emit(f"for value in {iterable} {{", depth + 1)
|
||||
else:
|
||||
self.emit("{", depth + 1)
|
||||
self.emit(f"let value = node.{fieldname};", depth + 2)
|
||||
self.emit(f"let value = node.{field_name};", depth + 2)
|
||||
|
||||
variable = "value"
|
||||
if field_type.boxed and (not f.seq or f.opt):
|
||||
if field_type.boxed and (not field.seq or field.opt):
|
||||
variable = "*" + variable
|
||||
type_info = self.type_info[field_type.name]
|
||||
self.emit(f"self.visit_{type_info.sum_name}({variable});", depth + 2)
|
||||
self.emit(f"self.visit_{type_info.full_field_name}({variable});", depth + 2)
|
||||
|
||||
self.emit("}", depth + 1)
|
||||
|
||||
|
@ -903,16 +1019,9 @@ class VisitorTraitDefVisitor(StructVisitor):
|
|||
self.emit_empty_generic_visitor(name, depth)
|
||||
|
||||
|
||||
class VisitorModuleVisitor(EmitVisitor):
|
||||
def visitModule(self, mod):
|
||||
depth = 0
|
||||
self.emit("#[allow(unused_variables, non_snake_case)]", depth)
|
||||
VisitorTraitDefVisitor(self.file, self.type_info).visit(mod, depth)
|
||||
|
||||
|
||||
class RangedDefVisitor(EmitVisitor):
|
||||
def visitModule(self, mod):
|
||||
for dfn in mod.dfns:
|
||||
for dfn in mod.dfns + CUSTOM_TYPES:
|
||||
self.visit(dfn)
|
||||
|
||||
def visitType(self, type, depth=0):
|
||||
|
@ -944,7 +1053,7 @@ class RangedDefVisitor(EmitVisitor):
|
|||
|
||||
self.emit(
|
||||
f"""
|
||||
impl Ranged for crate::{info.rust_sum_name} {{
|
||||
impl Ranged for crate::{info.full_type_name} {{
|
||||
fn range(&self) -> TextRange {{
|
||||
match self {{
|
||||
{sum_match_arms}
|
||||
|
@ -966,7 +1075,7 @@ class RangedDefVisitor(EmitVisitor):
|
|||
generics = "" if info.is_simple else "::<TextRange>"
|
||||
|
||||
self.emit(
|
||||
f"pub type {info.rust_sum_name} = crate::generic::{info.rust_sum_name}{generics};",
|
||||
f"pub type {info.full_type_name} = crate::generic::{info.full_type_name}{generics};",
|
||||
0,
|
||||
)
|
||||
self.emit("", 0)
|
||||
|
@ -977,7 +1086,7 @@ class RangedDefVisitor(EmitVisitor):
|
|||
|
||||
self.file.write(
|
||||
f"""
|
||||
impl Ranged for crate::generic::{info.rust_sum_name}::<TextRange> {{
|
||||
impl Ranged for crate::generic::{info.full_type_name}::<TextRange> {{
|
||||
fn range(&self) -> TextRange {{
|
||||
self.range
|
||||
}}
|
||||
|
@ -988,7 +1097,7 @@ class RangedDefVisitor(EmitVisitor):
|
|||
|
||||
class LocatedDefVisitor(EmitVisitor):
|
||||
def visitModule(self, mod):
|
||||
for dfn in mod.dfns:
|
||||
for dfn in mod.dfns + CUSTOM_TYPES:
|
||||
self.visit(dfn)
|
||||
|
||||
def visitType(self, type, depth=0):
|
||||
|
@ -1020,7 +1129,7 @@ class LocatedDefVisitor(EmitVisitor):
|
|||
|
||||
self.emit(
|
||||
f"""
|
||||
impl Located for {info.rust_sum_name} {{
|
||||
impl Located for {info.full_type_name} {{
|
||||
fn range(&self) -> SourceRange {{
|
||||
match self {{
|
||||
{sum_match_arms}
|
||||
|
@ -1041,7 +1150,7 @@ class LocatedDefVisitor(EmitVisitor):
|
|||
generics = "" if info.is_simple else "::<SourceRange>"
|
||||
|
||||
self.emit(
|
||||
f"pub type {info.rust_sum_name} = crate::generic::{info.rust_sum_name}{generics};",
|
||||
f"pub type {info.full_type_name} = crate::generic::{info.full_type_name}{generics};",
|
||||
0,
|
||||
)
|
||||
self.emit("", 0)
|
||||
|
@ -1052,7 +1161,7 @@ class LocatedDefVisitor(EmitVisitor):
|
|||
|
||||
self.emit(
|
||||
f"""
|
||||
impl Located for {info.rust_sum_name} {{
|
||||
impl Located for {info.full_type_name} {{
|
||||
fn range(&self) -> SourceRange {{
|
||||
self.range
|
||||
}}
|
||||
|
@ -1086,11 +1195,13 @@ class ToPyo3AstVisitor(EmitVisitor):
|
|||
self.visit(type.value, type)
|
||||
|
||||
def visitProduct(self, product, type):
|
||||
rust_name = rust_type_name(type.name)
|
||||
info = self.type_info[type.name]
|
||||
rust_name = info.full_type_name
|
||||
self.emit_to_pyo3_with_fields(product, type, rust_name)
|
||||
|
||||
def visitSum(self, sum, type):
|
||||
rust_name = rust_type_name(type.name)
|
||||
info = self.type_info[type.name]
|
||||
rust_name = info.full_type_name
|
||||
simple = is_simple(sum)
|
||||
if is_simple(sum):
|
||||
return
|
||||
|
@ -1243,8 +1354,9 @@ class Pyo3StructVisitor(EmitVisitor):
|
|||
def ref(self):
|
||||
return "&" if self.borrow else ""
|
||||
|
||||
def emit_class(self, name, rust_name, simple, base="super::Ast"):
|
||||
info = self.type_info[name]
|
||||
def emit_class(self, info, simple, base="super::Ast"):
|
||||
inner_name = info.full_type_name
|
||||
rust_name = self.type_info[info.custom.name].full_type_name
|
||||
if simple:
|
||||
generics = ""
|
||||
else:
|
||||
|
@ -1255,23 +1367,24 @@ class Pyo3StructVisitor(EmitVisitor):
|
|||
into = f"{rust_name}"
|
||||
else:
|
||||
subclass = ""
|
||||
body = f"(pub {self.ref_def} ast::{rust_name}{generics})"
|
||||
body = f"(pub {self.ref_def} ast::{inner_name}{generics})"
|
||||
into = f"{rust_name}(node)"
|
||||
|
||||
self.emit(
|
||||
f"""
|
||||
#[pyclass(module="{self.module_name}", name="_{name}", extends={base}, frozen{subclass})]
|
||||
#[pyclass(module="{self.module_name}", name="_{info.name}", extends={base}, frozen{subclass})]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct {rust_name} {body};
|
||||
|
||||
impl From<{self.ref_def} ast::{rust_name}{generics}> for {rust_name} {{
|
||||
fn from({"" if body else "_"}node: {self.ref_def} ast::{rust_name}{generics}) -> Self {{
|
||||
impl From<{self.ref_def} ast::{inner_name}{generics}> for {rust_name} {{
|
||||
fn from({"" if body else "_"}node: {self.ref_def} ast::{inner_name}{generics}) -> Self {{
|
||||
{into}
|
||||
}}
|
||||
}}
|
||||
""",
|
||||
0,
|
||||
)
|
||||
|
||||
if subclass:
|
||||
self.emit(
|
||||
f"""
|
||||
|
@ -1313,7 +1426,7 @@ class Pyo3StructVisitor(EmitVisitor):
|
|||
)
|
||||
|
||||
if not subclass:
|
||||
self.emit_wrapper(rust_name)
|
||||
self.emit_wrapper(info)
|
||||
|
||||
def emit_getter(self, owner, type_name):
|
||||
self.emit(
|
||||
|
@ -1371,10 +1484,12 @@ class Pyo3StructVisitor(EmitVisitor):
|
|||
0,
|
||||
)
|
||||
|
||||
def emit_wrapper(self, rust_name):
|
||||
def emit_wrapper(self, info):
|
||||
inner_name = info.full_type_name
|
||||
rust_name = self.type_info[info.custom.name].full_type_name
|
||||
self.emit(
|
||||
f"""
|
||||
impl ToPyWrapper for ast::{rust_name}{self.generics} {{
|
||||
impl ToPyWrapper for ast::{inner_name}{self.generics} {{
|
||||
#[inline]
|
||||
fn to_py_wrapper(&'static self, py: Python) -> PyResult<Py<PyAny>> {{
|
||||
Ok({rust_name}(self).to_object(py))
|
||||
|
@ -1392,10 +1507,11 @@ class Pyo3StructVisitor(EmitVisitor):
|
|||
self.visit(type.value, type, depth)
|
||||
|
||||
def visitSum(self, sum, type, depth=0):
|
||||
info = self.type_info[type.name]
|
||||
rust_name = rust_type_name(type.name)
|
||||
|
||||
simple = is_simple(sum)
|
||||
self.emit_class(type.name, rust_name, simple)
|
||||
self.emit_class(info, simple)
|
||||
|
||||
if not simple:
|
||||
self.emit(
|
||||
|
@ -1424,8 +1540,9 @@ class Pyo3StructVisitor(EmitVisitor):
|
|||
self.visit(cons, rust_name, simple, depth + 1)
|
||||
|
||||
def visitProduct(self, product, type, depth=0):
|
||||
info = self.type_info[type.name]
|
||||
rust_name = rust_type_name(type.name)
|
||||
self.emit_class(type.name, rust_name, False)
|
||||
self.emit_class(info, False)
|
||||
if self.borrow:
|
||||
self.emit_getter(product, rust_name)
|
||||
|
||||
|
@ -1448,9 +1565,9 @@ class Pyo3StructVisitor(EmitVisitor):
|
|||
depth,
|
||||
)
|
||||
else:
|
||||
info = self.type_info[cons.name]
|
||||
self.emit_class(
|
||||
cons.name,
|
||||
f"{parent}{cons.name}",
|
||||
info,
|
||||
simple=False,
|
||||
base=parent,
|
||||
)
|
||||
|
@ -1471,23 +1588,25 @@ class Pyo3PymoduleVisitor(EmitVisitor):
|
|||
self.visit(type.value, type.name, depth)
|
||||
|
||||
def visitProduct(self, product, name, depth=0):
|
||||
rust_name = rust_type_name(name)
|
||||
self.emit_fields(name, rust_name, False)
|
||||
info = self.type_info[name]
|
||||
self.emit_fields(info, False)
|
||||
|
||||
def visitSum(self, sum, name, depth):
|
||||
rust_name = rust_type_name(name)
|
||||
info = self.type_info[name]
|
||||
simple = is_simple(sum)
|
||||
self.emit_fields(name, rust_name, True)
|
||||
self.emit_fields(info, True)
|
||||
|
||||
for cons in sum.types:
|
||||
self.visit(cons, name, simple, depth)
|
||||
|
||||
def visitConstructor(self, cons, parent, simple, depth):
|
||||
rust_name = rust_type_name(parent) + rust_type_name(cons.name)
|
||||
self.emit_fields(cons.name, rust_name, simple)
|
||||
info = self.type_info[cons.name]
|
||||
self.emit_fields(info, simple)
|
||||
|
||||
def emit_fields(self, name, rust_name, simple):
|
||||
self.emit(f"super::init_type::<{rust_name}, ast::{rust_name}>(py, m)?;", 1)
|
||||
def emit_fields(self, info, simple):
|
||||
inner_name = info.full_type_name
|
||||
rust_name = self.type_info[info.custom.name].full_type_name
|
||||
self.emit(f"super::init_type::<{rust_name}, ast::{inner_name}>(py, m)?;", 1)
|
||||
|
||||
|
||||
class StdlibClassDefVisitor(EmitVisitor):
|
||||
|
@ -1499,7 +1618,9 @@ class StdlibClassDefVisitor(EmitVisitor):
|
|||
self.visit(type.value, type.name, depth)
|
||||
|
||||
def visitSum(self, sum, name, depth):
|
||||
struct_name = "Node" + rust_type_name(name)
|
||||
# info = self.type_info[self.type_info[name].custom.name]
|
||||
info = self.type_info[name]
|
||||
struct_name = "Node" + info.full_type_name
|
||||
self.emit(
|
||||
f'#[pyclass(module = "_ast", name = {json.dumps(name)}, base = "NodeAst")]',
|
||||
depth,
|
||||
|
@ -1517,11 +1638,13 @@ class StdlibClassDefVisitor(EmitVisitor):
|
|||
self.gen_class_def(name, product.fields, product.attributes, depth)
|
||||
|
||||
def gen_class_def(self, name, fields, attrs, depth, base=None):
|
||||
|
||||
info = self.type_info[self.type_info[name].custom.name]
|
||||
if base is None:
|
||||
base = "NodeAst"
|
||||
struct_name = "Node" + rust_type_name(name)
|
||||
struct_name = "Node" + info.full_type_name
|
||||
else:
|
||||
struct_name = base + rust_type_name(name)
|
||||
struct_name = "Node" + info.full_type_name
|
||||
self.emit(
|
||||
f'#[pyclass(module = "_ast", name = {json.dumps(name)}, base = {json.dumps(base)})]',
|
||||
depth,
|
||||
|
@ -1596,11 +1719,9 @@ class StdlibTraitImplVisitor(EmitVisitor):
|
|||
self.visit(type.value, type.name, depth)
|
||||
|
||||
def visitSum(self, sum, name, depth):
|
||||
rust_name = rust_type_name(name)
|
||||
info = self.type_info[name]
|
||||
rust_name = info.full_type_name
|
||||
|
||||
self.emit(f"impl NamedNode for ast::located::{rust_name} {{", depth)
|
||||
self.emit(f"const NAME: &'static str = {json.dumps(name)};", depth + 1)
|
||||
self.emit("}", depth)
|
||||
self.emit("// sum", depth)
|
||||
self.emit(f"impl Node for ast::located::{rust_name} {{", depth)
|
||||
self.emit(
|
||||
|
@ -1644,11 +1765,6 @@ class StdlibTraitImplVisitor(EmitVisitor):
|
|||
def visitConstructor(self, cons, sum, sum_rust_name, depth):
|
||||
rust_name = rust_type_name(cons.name)
|
||||
self.emit("// constructor", depth)
|
||||
self.emit(
|
||||
f"impl NamedNode for ast::located::{sum_rust_name}{rust_name} {{", depth
|
||||
)
|
||||
self.emit(f"const NAME: &'static str = {json.dumps(cons.name)};", depth + 1)
|
||||
self.emit("}", depth)
|
||||
self.emit(f"impl Node for ast::located::{sum_rust_name}{rust_name} {{", depth)
|
||||
|
||||
fields_pattern = self.make_pattern(cons.fields)
|
||||
|
@ -1677,12 +1793,10 @@ class StdlibTraitImplVisitor(EmitVisitor):
|
|||
self.emit("}", depth + 1)
|
||||
|
||||
def visitProduct(self, product, name, depth):
|
||||
struct_name = rust_type_name(name)
|
||||
info = self.type_info[name]
|
||||
struct_name = info.full_type_name
|
||||
|
||||
self.emit("// product", depth)
|
||||
self.emit(f"impl NamedNode for ast::located::{struct_name} {{", depth)
|
||||
self.emit(f"const NAME: &'static str = {json.dumps(name)};", depth + 1)
|
||||
self.emit("}", depth)
|
||||
self.emit(f"impl Node for ast::located::{struct_name} {{", depth)
|
||||
self.emit(
|
||||
"fn ast_to_object(self, _vm: &VirtualMachine) -> PyObjectRef {", depth + 1
|
||||
|
@ -1826,8 +1940,7 @@ def write_located_def(mod, type_info, f):
|
|||
|
||||
|
||||
def write_pyo3_node(type_info, f):
|
||||
def write(info: TypeInfo):
|
||||
rust_name = info.rust_sum_name
|
||||
def write(info: TypeInfo, rust_name: str):
|
||||
if info.is_simple:
|
||||
generics = ""
|
||||
else:
|
||||
|
@ -1845,8 +1958,14 @@ def write_pyo3_node(type_info, f):
|
|||
""",
|
||||
)
|
||||
|
||||
for info in type_info.values():
|
||||
write(info)
|
||||
for type_name, info in type_info.items():
|
||||
rust_name = info.full_type_name
|
||||
if info.is_custom:
|
||||
if type_name != info.type.name:
|
||||
rust_name = "Python" + rust_name
|
||||
else:
|
||||
continue
|
||||
write(info, rust_name)
|
||||
|
||||
|
||||
def write_to_pyo3(mod, type_info, f):
|
||||
|
@ -1864,7 +1983,9 @@ def write_to_pyo3(mod, type_info, f):
|
|||
)
|
||||
|
||||
for info in type_info.values():
|
||||
rust_name = info.rust_sum_name
|
||||
if info.is_custom:
|
||||
continue
|
||||
rust_name = info.full_type_name
|
||||
f.write(f"cache_py_type::<ast::{rust_name}>(ast_module)?;\n")
|
||||
f.write("Ok(())\n}")
|
||||
|
||||
|
@ -1876,7 +1997,7 @@ def write_to_pyo3_simple(type_info, f):
|
|||
if not type_info.is_simple:
|
||||
continue
|
||||
|
||||
rust_name = type_info.rust_sum_name
|
||||
rust_name = type_info.full_type_name
|
||||
f.write(
|
||||
f"""
|
||||
impl ToPyAst for ast::{rust_name} {{
|
||||
|
@ -1903,20 +2024,21 @@ def write_pyo3_wrapper(mod, type_info, namespace, f):
|
|||
Pyo3StructVisitor(namespace, f, type_info).visit(mod)
|
||||
|
||||
if namespace == "located":
|
||||
for type_info in type_info.values():
|
||||
if not type_info.is_simple or not type_info.is_sum:
|
||||
for info in type_info.values():
|
||||
if not info.is_simple or not info.is_sum:
|
||||
continue
|
||||
|
||||
rust_name = type_info.rust_sum_name
|
||||
rust_name = info.full_type_name
|
||||
inner_name = type_info[info.custom.name].full_type_name
|
||||
f.write(
|
||||
f"""
|
||||
impl ToPyWrapper for ast::{rust_name} {{
|
||||
impl ToPyWrapper for ast::{inner_name} {{
|
||||
#[inline]
|
||||
fn to_py_wrapper(&self, py: Python) -> PyResult<Py<PyAny>> {{
|
||||
match &self {{
|
||||
""",
|
||||
)
|
||||
for cons in type_info.type.value.types:
|
||||
for cons in info.type.value.types:
|
||||
f.write(
|
||||
f"Self::{cons.name} => Ok({rust_name}{cons.name}.to_object(py)),",
|
||||
)
|
||||
|
@ -1928,7 +2050,7 @@ def write_pyo3_wrapper(mod, type_info, namespace, f):
|
|||
""",
|
||||
)
|
||||
|
||||
for cons in type_info.type.value.types:
|
||||
for cons in info.type.value.types:
|
||||
f.write(
|
||||
f"""
|
||||
impl ToPyWrapper for ast::{rust_name}{cons.name} {{
|
||||
|
@ -1960,7 +2082,7 @@ def write_parse_def(mod, type_info, f):
|
|||
cons_name = rust_type_name(info.name)
|
||||
|
||||
f.write(f"""
|
||||
impl Parse for ast::{info.rust_sum_name} {{
|
||||
impl Parse for ast::{info.full_type_name} {{
|
||||
fn lex_starts_at(
|
||||
source: &str,
|
||||
offset: TextSize,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue