diff --git a/ast/asdl_rs.py b/ast/asdl_rs.py index 5ec4cfd..6b4f2e2 100755 --- a/ast/asdl_rs.py +++ b/ast/asdl_rs.py @@ -10,7 +10,7 @@ import re from argparse import ArgumentParser from pathlib import Path -from typing import Optional, Dict +from typing import Optional, Dict, Any import asdl @@ -30,7 +30,18 @@ BUILTIN_INT_NAMES = { "is_async": "bool", } -RUST_KEYWORDS = {"if", "while", "for", "return", "match", "try", "await", "yield"} +RUST_KEYWORDS = { + "if", + "while", + "for", + "return", + "match", + "try", + "await", + "yield", + "in", + "mod", +} def rust_field_name(name): @@ -80,34 +91,52 @@ def asdl_of(name, obj): class TypeInfo: - name: str + type: asdl.Type enum_name: Optional[str] has_user_data: Optional[bool] has_attributes: bool is_simple: bool - empty_field: bool children: set + fields: Optional[Any] boxed: bool - product: bool - has_expr: bool = False - def __init__(self, name): - self.name = name + def __init__(self, type): + self.type = type self.enum_name = None self.has_user_data = None self.has_attributes = False self.is_simple = False - self.empty_field = False self.children = set() + self.fields = None self.boxed = False - self.product = False - self.product_has_expr = False def __repr__(self): return f"" + @property + def name(self): + return self.type.name + + @property + def is_type(self): + return isinstance(self.type, asdl.Type) + + @property + def is_product(self): + return self.is_type and isinstance(self.type.value, asdl.Product) + + @property + def is_sum(self): + return self.is_type and isinstance(self.type.value, asdl.Sum) + + @property + def has_expr(self): + return self.is_product and any( + f.type != "identifier" for f in self.type.value.fields + ) + def no_cfg(self, typeinfo): - if self.product: + if self.is_product: return self.has_attributes elif self.enum_name: return typeinfo[self.enum_name].has_attributes @@ -199,42 +228,46 @@ class FindUserDataTypesVisitor(asdl.VisitorBase): info.determine_user_data(self.type_info, stack) def visitType(self, type): - self.type_info[type.name] = TypeInfo(type.name) - self.visit(type.value, type.name) + self.type_info[type.name] = TypeInfo(type) + self.visit(type.value, type) - def visitSum(self, sum, name): - info = self.type_info[name] - if is_simple(sum): + def visitSum(self, sum, type): + info = self.type_info[type.name] + info.is_simple = is_simple(sum) + for cons in sum.types: + self.visit(cons, type, info.is_simple) + + if info.is_simple: info.has_user_data = False - info.is_simple = True - else: - for t in sum.types: - t_info = TypeInfo(t.name) - t_info.enum_name = name - t_info.empty_field = not t.fields - self.type_info[t.name] = t_info - self.add_children(t.name, t.fields) - if len(sum.types) > 1: - info.boxed = True - if sum.attributes: - # attributes means located, which has the `range: R` field - info.has_user_data = True - info.has_attributes = True + return + + for t in sum.types: + self.add_children(t.name, t.fields) + + if len(sum.types) > 1: + info.boxed = True + if sum.attributes: + # attributes means located, which has the `range: R` field + info.has_user_data = True + info.has_attributes = True for variant in sum.types: - self.add_children(name, variant.fields) + self.add_children(type.name, variant.fields) - def visitProduct(self, product, name): - info = self.type_info[name] + def visitConstructor(self, cons, type, simple): + info = self.type_info[cons.name] = TypeInfo(cons) + info.enum_name = type.name + info.is_simple = simple + + def visitProduct(self, product, type): + info = self.type_info[type.name] if product.attributes: # attributes means located, which has the `range: R` field info.has_user_data = True info.has_attributes = True - info.has_expr = product_has_expr(product) if len(product.fields) > 2: info.boxed = True - info.product = True - self.add_children(name, product.fields) + self.add_children(type.name, product.fields) def add_children(self, name, fields): self.type_info[name].children.update( @@ -249,10 +282,6 @@ def rust_field(field_name): return field_name -def product_has_expr(product): - return any(f.type != "identifier" for f in product.fields) - - class StructVisitor(EmitVisitor): """Visitor to generate type-defs for AST.""" @@ -264,19 +293,19 @@ class StructVisitor(EmitVisitor): self.visit(dfn) def visitType(self, type, depth=0): - self.visit(type.value, type.name, depth) + self.visit(type.value, type, depth) - def visitSum(self, sum, name, depth): + def visitSum(self, sum, type, depth): if is_simple(sum): - self.simple_sum(sum, name, depth) + self.simple_sum(sum, type, depth) else: - self.sum_with_constructors(sum, name, depth) + self.sum_with_constructors(sum, type, depth) - (generics_applied,) = self.apply_generics(name, "R") + (generics_applied,) = self.apply_generics(type.name, "R") self.emit( f""" - impl{generics_applied} Node for {rust_type_name(name)}{generics_applied} {{ - const NAME: &'static str = "{name}"; + impl{generics_applied} Node for {rust_type_name(type.name)}{generics_applied} {{ + const NAME: &'static str = "{type.name}"; const FIELD_NAMES: &'static [&'static str] = &[]; }} """, @@ -292,19 +321,64 @@ class StructVisitor(EmitVisitor): else: self.emit("pub range: OptionalRange,", depth + 1) - def simple_sum(self, sum, name, depth): - rust_name = rust_type_name(name) + def simple_sum(self, sum, type, depth): + rust_name = rust_type_name(type.name) self.emit_attrs(depth) self.emit("#[derive(is_macro::Is, Copy, Hash, Eq)]", depth) self.emit(f"pub enum {rust_name} {{", depth) - for variant in sum.types: - self.emit(f"{variant.name},", depth + 1) + for cons in sum.types: + self.emit(f"{cons.name},", depth + 1) + self.emit("}", depth) + self.emit(f"impl {rust_name} {{", depth) + needs_escape = any(rust_field_name(t.name) in RUST_KEYWORDS for t in sum.types) + if needs_escape: + prefix = rust_field_name(type.name) + "_" + else: + prefix = "" + for cons in sum.types: + self.emit( + textwrap.dedent( + f""" + #[inline] + pub const fn {prefix}{rust_field_name(cons.name)}(&self) -> Option<{rust_name}{cons.name}> {{ + match self {{ + {rust_name}::{cons.name} => Some({rust_name}{cons.name}), + _ => None, + }} + }} + """ + ), + depth, + ) self.emit("}", depth) self.emit("", depth) - def sum_with_constructors(self, sum, name, depth): - type_info = self.type_info[name] - rust_name = rust_type_name(name) + for cons in sum.types: + self.emit( + f""" + pub struct {rust_name}{cons.name}; + impl From<{rust_name}{cons.name}> for {rust_name} {{ + fn from(_: {rust_name}{cons.name}) -> Self {{ + {rust_name}::{cons.name} + }} + }} + impl Node for {rust_name}{cons.name} {{ + const NAME: &'static str = "{cons.name}"; + const FIELD_NAMES: &'static [&'static str] = &[]; + }} + impl std::cmp::PartialEq<{rust_name}> for {rust_name}{cons.name} {{ + #[inline] + fn eq(&self, other: &{rust_name}) -> bool {{ + matches!(other, {rust_name}::{cons.name}) + }} + }} + """, + 0, + ) + + def sum_with_constructors(self, sum, type, depth): + type_info = self.type_info[type.name] + rust_name = rust_type_name(type.name) # all the attributes right now are for location, so if it has attrs we # can just wrap it in Attributed<> @@ -376,7 +450,7 @@ class StructVisitor(EmitVisitor): if ( field_type and field_type.boxed - and (not (parent.product or field.seq) or field.opt) + and (not (parent.is_product or field.seq) or field.opt) ): typ = f"Box<{typ}>" if field.opt or ( @@ -394,9 +468,9 @@ class StructVisitor(EmitVisitor): name = rust_field(field.name) self.emit(f"{vis}{name}: {typ},", depth) - def visitProduct(self, product, name, depth): - type_info = self.type_info[name] - product_name = rust_type_name(name) + def visitProduct(self, product, type, depth): + type_info = self.type_info[type.name] + product_name = rust_type_name(type.name) self.emit_attrs(depth) self.emit(f"pub struct {product_name} {{", depth) @@ -410,7 +484,7 @@ class StructVisitor(EmitVisitor): self.emit( f""" impl Node for {product_name} {{ - const NAME: &'static str = "{name}"; + const NAME: &'static str = "{type.name}"; const FIELD_NAMES: &'static [&'static str] = &[ {', '.join(field_names)} ]; @@ -711,6 +785,158 @@ class VisitorModuleVisitor(EmitVisitor): VisitorTraitDefVisitor(self.file, self.type_info).visit(mod, depth) +class RangedDefVisitor(EmitVisitor): + def visitModule(self, mod): + for dfn in mod.dfns: + self.visit(dfn) + + def visitType(self, type, depth=0): + self.visit(type.value, type.name, depth) + + def visitSum(self, sum, name, depth): + info = self.type_info[name] + + self.emit_type_alias(info) + + if info.is_simple: + for ty in sum.types: + variant_info = self.type_info[ty.name] + self.emit_type_alias(variant_info) + return + + sum_match_arms = "" + + for ty in sum.types: + variant_info = self.type_info[ty.name] + sum_match_arms += ( + f" Self::{variant_info.rust_name}(node) => node.range()," + ) + self.emit_type_alias(variant_info) + self.emit_ranged_impl(variant_info) + + if not info.no_cfg(self.type_info): + self.emit('#[cfg(feature = "all-nodes-with-ranges")]', 0) + + self.emit( + f""" + impl Ranged for crate::{info.rust_sum_name} {{ + fn range(&self) -> TextRange {{ + match self {{ + {sum_match_arms} + }} + }} + }} + """.lstrip(), + 0, + ) + + def visitProduct(self, product, name, depth): + info = self.type_info[name] + + self.emit_type_alias(info) + self.emit_ranged_impl(info) + + def emit_type_alias(self, info): + return # disable + generics = "" if info.is_simple else "::" + + self.emit( + f"pub type {info.rust_sum_name} = crate::generic::{info.rust_sum_name}{generics};", + 0, + ) + self.emit("", 0) + + def emit_ranged_impl(self, info): + if not info.no_cfg(self.type_info): + self.emit('#[cfg(feature = "all-nodes-with-ranges")]', 0) + + self.file.write( + f""" + impl Ranged for crate::generic::{info.rust_sum_name}:: {{ + fn range(&self) -> TextRange {{ + self.range + }} + }} + """.strip() + ) + + +class LocatedDefVisitor(EmitVisitor): + def visitModule(self, mod): + for dfn in mod.dfns: + self.visit(dfn) + + def visitType(self, type, depth=0): + self.visit(type.value, type.name, depth) + + def visitSum(self, sum, name, depth): + info = self.type_info[name] + + self.emit_type_alias(info) + + if info.is_simple: + for ty in sum.types: + variant_info = self.type_info[ty.name] + self.emit_type_alias(variant_info) + return + + sum_match_arms = "" + + for ty in sum.types: + variant_info = self.type_info[ty.name] + sum_match_arms += ( + f" Self::{variant_info.rust_name}(node) => node.range()," + ) + self.emit_type_alias(variant_info) + self.emit_located_impl(variant_info) + + if not info.no_cfg(self.type_info): + self.emit('#[cfg(feature = "all-nodes-with-ranges")]', 0) + + self.emit( + f""" + impl Located for {info.rust_sum_name} {{ + fn range(&self) -> SourceRange {{ + match self {{ + {sum_match_arms} + }} + }} + }} + """.lstrip(), + 0, + ) + + def visitProduct(self, product, name, depth): + info = self.type_info[name] + + self.emit_type_alias(info) + self.emit_located_impl(info) + + def emit_type_alias(self, info): + generics = "" if info.is_simple else "::" + + self.emit( + f"pub type {info.rust_sum_name} = crate::generic::{info.rust_sum_name}{generics};", + 0, + ) + self.emit("", 0) + + def emit_located_impl(self, info): + if not info.no_cfg(self.type_info): + self.emit('#[cfg(feature = "all-nodes-with-ranges")]', 0) + + self.emit( + f""" + impl Located for {info.rust_sum_name} {{ + fn range(&self) -> SourceRange {{ + self.range + }} + }} + """, + 0, + ) + + class StdlibClassDefVisitor(EmitVisitor): def visitModule(self, mod): for dfn in mod.dfns: @@ -836,7 +1062,10 @@ class StdlibTraitImplVisitor(EmitVisitor): depth, ) self.emit("};", depth + 3) - self.emit("NodeAst.into_ref_with_type(vm, node_type.to_owned()).unwrap().into()", depth + 2) + self.emit( + "NodeAst.into_ref_with_type(vm, node_type.to_owned()).unwrap().into()", + depth + 2, + ) else: self.emit("match self {", depth + 2) for cons in sum.types: @@ -1011,138 +1240,6 @@ class StdlibTraitImplVisitor(EmitVisitor): return f"Node::ast_from_object(_vm, get_node_field(_vm, &_object, {name}, {json.dumps(typename)})?)?" -class RangedDefVisitor(EmitVisitor): - def visitModule(self, mod): - for dfn in mod.dfns: - self.visit(dfn) - - def visitType(self, type, depth=0): - self.visit(type.value, type.name, depth) - - def visitSum(self, sum, name, depth): - info = self.type_info[name] - - if info.is_simple: - return - - sum_match_arms = "" - - for ty in sum.types: - variant_info = self.type_info[ty.name] - sum_match_arms += ( - f" Self::{variant_info.rust_name}(node) => node.range()," - ) - self.emit_ranged_impl(variant_info) - - if not info.no_cfg(self.type_info): - self.emit('#[cfg(feature = "all-nodes-with-ranges")]', 0) - - self.emit( - f""" - impl Ranged for crate::{info.rust_sum_name} {{ - fn range(&self) -> TextRange {{ - match self {{ - {sum_match_arms} - }} - }} - }} - """.lstrip(), - 0, - ) - - def visitProduct(self, product, name, depth): - info = self.type_info[name] - - self.emit_ranged_impl(info) - - def emit_ranged_impl(self, info): - if not info.no_cfg(self.type_info): - self.emit('#[cfg(feature = "all-nodes-with-ranges")]', 0) - - self.file.write( - f""" - impl Ranged for crate::generic::{info.rust_sum_name}:: {{ - fn range(&self) -> TextRange {{ - self.range - }} - }} - """.strip() - ) - - -class LocatedDefVisitor(EmitVisitor): - def visitModule(self, mod): - for dfn in mod.dfns: - self.visit(dfn) - - def visitType(self, type, depth=0): - self.visit(type.value, type.name, depth) - - def visitSum(self, sum, name, depth): - info = self.type_info[name] - - self.emit_type_alias(info) - - if info.is_simple: - return - - sum_match_arms = "" - - for ty in sum.types: - variant_info = self.type_info[ty.name] - sum_match_arms += ( - f" Self::{variant_info.rust_name}(node) => node.range()," - ) - self.emit_type_alias(variant_info) - self.emit_located_impl(variant_info) - - if not info.no_cfg(self.type_info): - self.emit('#[cfg(feature = "all-nodes-with-ranges")]', 0) - - self.emit( - f""" - impl Located for {info.rust_sum_name} {{ - fn range(&self) -> SourceRange {{ - match self {{ - {sum_match_arms} - }} - }} - }} - """.lstrip(), - 0, - ) - - def visitProduct(self, product, name, depth): - info = self.type_info[name] - - self.emit_type_alias(info) - self.emit_located_impl(info) - - def emit_type_alias(self, info): - generics = "" if info.is_simple else "::" - - self.emit( - f"pub type {info.rust_sum_name} = crate::generic::{info.rust_sum_name}{generics};", - 0, - ) - self.emit("", 0) - - def emit_located_impl(self, info): - if not info.no_cfg(self.type_info): - self.emit('#[cfg(feature = "all-nodes-with-ranges")]', 0) - - self.emit( - f""" - impl Located for {info.rust_sum_name} {{ - fn range(&self) -> SourceRange {{ - self.range - }} - }} - """, - 0, - ) - - class ChainOfVisitors: def __init__(self, *visitors): self.visitors = visitors @@ -1211,16 +1308,18 @@ def main( type_info = {} FindUserDataTypesVisitor(type_info).visit(mod) + from functools import partial as p + for filename, write in [ - ("generic", write_ast_def), - ("fold", write_fold_def), - ("ranged", write_ranged_def), - ("located", write_located_def), - ("visitor", write_visitor_def), + ("generic", p(write_ast_def, mod, type_info)), + ("fold", p(write_fold_def, mod, type_info)), + ("ranged", p(write_ranged_def, mod, type_info)), + ("located", p(write_located_def, mod, type_info)), + ("visitor", p(write_visitor_def, mod, type_info)), ]: with (ast_dir / f"{filename}.rs").open("w") as f: f.write(auto_gen_msg) - write(mod, type_info, f) + write(f) with module_filename.open("w") as module_file: module_file.write(auto_gen_msg) diff --git a/ast/src/gen/generic.rs b/ast/src/gen/generic.rs index 40643a2..fce8459 100644 --- a/ast/src/gen/generic.rs +++ b/ast/src/gen/generic.rs @@ -1169,6 +1169,82 @@ pub enum ExprContext { Store, Del, } +impl ExprContext { + #[inline] + pub const fn load(&self) -> Option { + match self { + ExprContext::Load => Some(ExprContextLoad), + _ => None, + } + } + + #[inline] + pub const fn store(&self) -> Option { + match self { + ExprContext::Store => Some(ExprContextStore), + _ => None, + } + } + + #[inline] + pub const fn del(&self) -> Option { + match self { + ExprContext::Del => Some(ExprContextDel), + _ => None, + } + } +} + +pub struct ExprContextLoad; +impl From for ExprContext { + fn from(_: ExprContextLoad) -> Self { + ExprContext::Load + } +} +impl Node for ExprContextLoad { + const NAME: &'static str = "Load"; + const FIELD_NAMES: &'static [&'static str] = &[]; +} +impl std::cmp::PartialEq for ExprContextLoad { + #[inline] + fn eq(&self, other: &ExprContext) -> bool { + matches!(other, ExprContext::Load) + } +} + +pub struct ExprContextStore; +impl From for ExprContext { + fn from(_: ExprContextStore) -> Self { + ExprContext::Store + } +} +impl Node for ExprContextStore { + const NAME: &'static str = "Store"; + const FIELD_NAMES: &'static [&'static str] = &[]; +} +impl std::cmp::PartialEq for ExprContextStore { + #[inline] + fn eq(&self, other: &ExprContext) -> bool { + matches!(other, ExprContext::Store) + } +} + +pub struct ExprContextDel; +impl From for ExprContext { + fn from(_: ExprContextDel) -> Self { + ExprContext::Del + } +} +impl Node for ExprContextDel { + const NAME: &'static str = "Del"; + const FIELD_NAMES: &'static [&'static str] = &[]; +} +impl std::cmp::PartialEq for ExprContextDel { + #[inline] + fn eq(&self, other: &ExprContext) -> bool { + matches!(other, ExprContext::Del) + } +} impl Node for ExprContext { const NAME: &'static str = "expr_context"; @@ -1180,6 +1256,57 @@ pub enum Boolop { And, Or, } +impl Boolop { + #[inline] + pub const fn and(&self) -> Option { + match self { + Boolop::And => Some(BoolopAnd), + _ => None, + } + } + + #[inline] + pub const fn or(&self) -> Option { + match self { + Boolop::Or => Some(BoolopOr), + _ => None, + } + } +} + +pub struct BoolopAnd; +impl From for Boolop { + fn from(_: BoolopAnd) -> Self { + Boolop::And + } +} +impl Node for BoolopAnd { + const NAME: &'static str = "And"; + const FIELD_NAMES: &'static [&'static str] = &[]; +} +impl std::cmp::PartialEq for BoolopAnd { + #[inline] + fn eq(&self, other: &Boolop) -> bool { + matches!(other, Boolop::And) + } +} + +pub struct BoolopOr; +impl From for Boolop { + fn from(_: BoolopOr) -> Self { + Boolop::Or + } +} +impl Node for BoolopOr { + const NAME: &'static str = "Or"; + const FIELD_NAMES: &'static [&'static str] = &[]; +} +impl std::cmp::PartialEq for BoolopOr { + #[inline] + fn eq(&self, other: &Boolop) -> bool { + matches!(other, Boolop::Or) + } +} impl Node for Boolop { const NAME: &'static str = "boolop"; @@ -1202,6 +1329,332 @@ pub enum Operator { BitAnd, FloorDiv, } +impl Operator { + #[inline] + pub const fn operator_add(&self) -> Option { + match self { + Operator::Add => Some(OperatorAdd), + _ => None, + } + } + + #[inline] + pub const fn operator_sub(&self) -> Option { + match self { + Operator::Sub => Some(OperatorSub), + _ => None, + } + } + + #[inline] + pub const fn operator_mult(&self) -> Option { + match self { + Operator::Mult => Some(OperatorMult), + _ => None, + } + } + + #[inline] + pub const fn operator_mat_mult(&self) -> Option { + match self { + Operator::MatMult => Some(OperatorMatMult), + _ => None, + } + } + + #[inline] + pub const fn operator_div(&self) -> Option { + match self { + Operator::Div => Some(OperatorDiv), + _ => None, + } + } + + #[inline] + pub const fn operator_mod(&self) -> Option { + match self { + Operator::Mod => Some(OperatorMod), + _ => None, + } + } + + #[inline] + pub const fn operator_pow(&self) -> Option { + match self { + Operator::Pow => Some(OperatorPow), + _ => None, + } + } + + #[inline] + pub const fn operator_l_shift(&self) -> Option { + match self { + Operator::LShift => Some(OperatorLShift), + _ => None, + } + } + + #[inline] + pub const fn operator_r_shift(&self) -> Option { + match self { + Operator::RShift => Some(OperatorRShift), + _ => None, + } + } + + #[inline] + pub const fn operator_bit_or(&self) -> Option { + match self { + Operator::BitOr => Some(OperatorBitOr), + _ => None, + } + } + + #[inline] + pub const fn operator_bit_xor(&self) -> Option { + match self { + Operator::BitXor => Some(OperatorBitXor), + _ => None, + } + } + + #[inline] + pub const fn operator_bit_and(&self) -> Option { + match self { + Operator::BitAnd => Some(OperatorBitAnd), + _ => None, + } + } + + #[inline] + pub const fn operator_floor_div(&self) -> Option { + match self { + Operator::FloorDiv => Some(OperatorFloorDiv), + _ => None, + } + } +} + +pub struct OperatorAdd; +impl From for Operator { + fn from(_: OperatorAdd) -> Self { + Operator::Add + } +} +impl Node for OperatorAdd { + const NAME: &'static str = "Add"; + const FIELD_NAMES: &'static [&'static str] = &[]; +} +impl std::cmp::PartialEq for OperatorAdd { + #[inline] + fn eq(&self, other: &Operator) -> bool { + matches!(other, Operator::Add) + } +} + +pub struct OperatorSub; +impl From for Operator { + fn from(_: OperatorSub) -> Self { + Operator::Sub + } +} +impl Node for OperatorSub { + const NAME: &'static str = "Sub"; + const FIELD_NAMES: &'static [&'static str] = &[]; +} +impl std::cmp::PartialEq for OperatorSub { + #[inline] + fn eq(&self, other: &Operator) -> bool { + matches!(other, Operator::Sub) + } +} + +pub struct OperatorMult; +impl From for Operator { + fn from(_: OperatorMult) -> Self { + Operator::Mult + } +} +impl Node for OperatorMult { + const NAME: &'static str = "Mult"; + const FIELD_NAMES: &'static [&'static str] = &[]; +} +impl std::cmp::PartialEq for OperatorMult { + #[inline] + fn eq(&self, other: &Operator) -> bool { + matches!(other, Operator::Mult) + } +} + +pub struct OperatorMatMult; +impl From for Operator { + fn from(_: OperatorMatMult) -> Self { + Operator::MatMult + } +} +impl Node for OperatorMatMult { + const NAME: &'static str = "MatMult"; + const FIELD_NAMES: &'static [&'static str] = &[]; +} +impl std::cmp::PartialEq for OperatorMatMult { + #[inline] + fn eq(&self, other: &Operator) -> bool { + matches!(other, Operator::MatMult) + } +} + +pub struct OperatorDiv; +impl From for Operator { + fn from(_: OperatorDiv) -> Self { + Operator::Div + } +} +impl Node for OperatorDiv { + const NAME: &'static str = "Div"; + const FIELD_NAMES: &'static [&'static str] = &[]; +} +impl std::cmp::PartialEq for OperatorDiv { + #[inline] + fn eq(&self, other: &Operator) -> bool { + matches!(other, Operator::Div) + } +} + +pub struct OperatorMod; +impl From for Operator { + fn from(_: OperatorMod) -> Self { + Operator::Mod + } +} +impl Node for OperatorMod { + const NAME: &'static str = "Mod"; + const FIELD_NAMES: &'static [&'static str] = &[]; +} +impl std::cmp::PartialEq for OperatorMod { + #[inline] + fn eq(&self, other: &Operator) -> bool { + matches!(other, Operator::Mod) + } +} + +pub struct OperatorPow; +impl From for Operator { + fn from(_: OperatorPow) -> Self { + Operator::Pow + } +} +impl Node for OperatorPow { + const NAME: &'static str = "Pow"; + const FIELD_NAMES: &'static [&'static str] = &[]; +} +impl std::cmp::PartialEq for OperatorPow { + #[inline] + fn eq(&self, other: &Operator) -> bool { + matches!(other, Operator::Pow) + } +} + +pub struct OperatorLShift; +impl From for Operator { + fn from(_: OperatorLShift) -> Self { + Operator::LShift + } +} +impl Node for OperatorLShift { + const NAME: &'static str = "LShift"; + const FIELD_NAMES: &'static [&'static str] = &[]; +} +impl std::cmp::PartialEq for OperatorLShift { + #[inline] + fn eq(&self, other: &Operator) -> bool { + matches!(other, Operator::LShift) + } +} + +pub struct OperatorRShift; +impl From for Operator { + fn from(_: OperatorRShift) -> Self { + Operator::RShift + } +} +impl Node for OperatorRShift { + const NAME: &'static str = "RShift"; + const FIELD_NAMES: &'static [&'static str] = &[]; +} +impl std::cmp::PartialEq for OperatorRShift { + #[inline] + fn eq(&self, other: &Operator) -> bool { + matches!(other, Operator::RShift) + } +} + +pub struct OperatorBitOr; +impl From for Operator { + fn from(_: OperatorBitOr) -> Self { + Operator::BitOr + } +} +impl Node for OperatorBitOr { + const NAME: &'static str = "BitOr"; + const FIELD_NAMES: &'static [&'static str] = &[]; +} +impl std::cmp::PartialEq for OperatorBitOr { + #[inline] + fn eq(&self, other: &Operator) -> bool { + matches!(other, Operator::BitOr) + } +} + +pub struct OperatorBitXor; +impl From for Operator { + fn from(_: OperatorBitXor) -> Self { + Operator::BitXor + } +} +impl Node for OperatorBitXor { + const NAME: &'static str = "BitXor"; + const FIELD_NAMES: &'static [&'static str] = &[]; +} +impl std::cmp::PartialEq for OperatorBitXor { + #[inline] + fn eq(&self, other: &Operator) -> bool { + matches!(other, Operator::BitXor) + } +} + +pub struct OperatorBitAnd; +impl From for Operator { + fn from(_: OperatorBitAnd) -> Self { + Operator::BitAnd + } +} +impl Node for OperatorBitAnd { + const NAME: &'static str = "BitAnd"; + const FIELD_NAMES: &'static [&'static str] = &[]; +} +impl std::cmp::PartialEq for OperatorBitAnd { + #[inline] + fn eq(&self, other: &Operator) -> bool { + matches!(other, Operator::BitAnd) + } +} + +pub struct OperatorFloorDiv; +impl From for Operator { + fn from(_: OperatorFloorDiv) -> Self { + Operator::FloorDiv + } +} +impl Node for OperatorFloorDiv { + const NAME: &'static str = "FloorDiv"; + const FIELD_NAMES: &'static [&'static str] = &[]; +} +impl std::cmp::PartialEq for OperatorFloorDiv { + #[inline] + fn eq(&self, other: &Operator) -> bool { + matches!(other, Operator::FloorDiv) + } +} impl Node for Operator { const NAME: &'static str = "operator"; @@ -1215,6 +1668,107 @@ pub enum Unaryop { UAdd, USub, } +impl Unaryop { + #[inline] + pub const fn invert(&self) -> Option { + match self { + Unaryop::Invert => Some(UnaryopInvert), + _ => None, + } + } + + #[inline] + pub const fn not(&self) -> Option { + match self { + Unaryop::Not => Some(UnaryopNot), + _ => None, + } + } + + #[inline] + pub const fn u_add(&self) -> Option { + match self { + Unaryop::UAdd => Some(UnaryopUAdd), + _ => None, + } + } + + #[inline] + pub const fn u_sub(&self) -> Option { + match self { + Unaryop::USub => Some(UnaryopUSub), + _ => None, + } + } +} + +pub struct UnaryopInvert; +impl From for Unaryop { + fn from(_: UnaryopInvert) -> Self { + Unaryop::Invert + } +} +impl Node for UnaryopInvert { + const NAME: &'static str = "Invert"; + const FIELD_NAMES: &'static [&'static str] = &[]; +} +impl std::cmp::PartialEq for UnaryopInvert { + #[inline] + fn eq(&self, other: &Unaryop) -> bool { + matches!(other, Unaryop::Invert) + } +} + +pub struct UnaryopNot; +impl From for Unaryop { + fn from(_: UnaryopNot) -> Self { + Unaryop::Not + } +} +impl Node for UnaryopNot { + const NAME: &'static str = "Not"; + const FIELD_NAMES: &'static [&'static str] = &[]; +} +impl std::cmp::PartialEq for UnaryopNot { + #[inline] + fn eq(&self, other: &Unaryop) -> bool { + matches!(other, Unaryop::Not) + } +} + +pub struct UnaryopUAdd; +impl From for Unaryop { + fn from(_: UnaryopUAdd) -> Self { + Unaryop::UAdd + } +} +impl Node for UnaryopUAdd { + const NAME: &'static str = "UAdd"; + const FIELD_NAMES: &'static [&'static str] = &[]; +} +impl std::cmp::PartialEq for UnaryopUAdd { + #[inline] + fn eq(&self, other: &Unaryop) -> bool { + matches!(other, Unaryop::UAdd) + } +} + +pub struct UnaryopUSub; +impl From for Unaryop { + fn from(_: UnaryopUSub) -> Self { + Unaryop::USub + } +} +impl Node for UnaryopUSub { + const NAME: &'static str = "USub"; + const FIELD_NAMES: &'static [&'static str] = &[]; +} +impl std::cmp::PartialEq for UnaryopUSub { + #[inline] + fn eq(&self, other: &Unaryop) -> bool { + matches!(other, Unaryop::USub) + } +} impl Node for Unaryop { const NAME: &'static str = "unaryop"; @@ -1234,6 +1788,257 @@ pub enum Cmpop { In, NotIn, } +impl Cmpop { + #[inline] + pub const fn cmpop_eq(&self) -> Option { + match self { + Cmpop::Eq => Some(CmpopEq), + _ => None, + } + } + + #[inline] + pub const fn cmpop_not_eq(&self) -> Option { + match self { + Cmpop::NotEq => Some(CmpopNotEq), + _ => None, + } + } + + #[inline] + pub const fn cmpop_lt(&self) -> Option { + match self { + Cmpop::Lt => Some(CmpopLt), + _ => None, + } + } + + #[inline] + pub const fn cmpop_lt_e(&self) -> Option { + match self { + Cmpop::LtE => Some(CmpopLtE), + _ => None, + } + } + + #[inline] + pub const fn cmpop_gt(&self) -> Option { + match self { + Cmpop::Gt => Some(CmpopGt), + _ => None, + } + } + + #[inline] + pub const fn cmpop_gt_e(&self) -> Option { + match self { + Cmpop::GtE => Some(CmpopGtE), + _ => None, + } + } + + #[inline] + pub const fn cmpop_is(&self) -> Option { + match self { + Cmpop::Is => Some(CmpopIs), + _ => None, + } + } + + #[inline] + pub const fn cmpop_is_not(&self) -> Option { + match self { + Cmpop::IsNot => Some(CmpopIsNot), + _ => None, + } + } + + #[inline] + pub const fn cmpop_in(&self) -> Option { + match self { + Cmpop::In => Some(CmpopIn), + _ => None, + } + } + + #[inline] + pub const fn cmpop_not_in(&self) -> Option { + match self { + Cmpop::NotIn => Some(CmpopNotIn), + _ => None, + } + } +} + +pub struct CmpopEq; +impl From for Cmpop { + fn from(_: CmpopEq) -> Self { + Cmpop::Eq + } +} +impl Node for CmpopEq { + const NAME: &'static str = "Eq"; + const FIELD_NAMES: &'static [&'static str] = &[]; +} +impl std::cmp::PartialEq for CmpopEq { + #[inline] + fn eq(&self, other: &Cmpop) -> bool { + matches!(other, Cmpop::Eq) + } +} + +pub struct CmpopNotEq; +impl From for Cmpop { + fn from(_: CmpopNotEq) -> Self { + Cmpop::NotEq + } +} +impl Node for CmpopNotEq { + const NAME: &'static str = "NotEq"; + const FIELD_NAMES: &'static [&'static str] = &[]; +} +impl std::cmp::PartialEq for CmpopNotEq { + #[inline] + fn eq(&self, other: &Cmpop) -> bool { + matches!(other, Cmpop::NotEq) + } +} + +pub struct CmpopLt; +impl From for Cmpop { + fn from(_: CmpopLt) -> Self { + Cmpop::Lt + } +} +impl Node for CmpopLt { + const NAME: &'static str = "Lt"; + const FIELD_NAMES: &'static [&'static str] = &[]; +} +impl std::cmp::PartialEq for CmpopLt { + #[inline] + fn eq(&self, other: &Cmpop) -> bool { + matches!(other, Cmpop::Lt) + } +} + +pub struct CmpopLtE; +impl From for Cmpop { + fn from(_: CmpopLtE) -> Self { + Cmpop::LtE + } +} +impl Node for CmpopLtE { + const NAME: &'static str = "LtE"; + const FIELD_NAMES: &'static [&'static str] = &[]; +} +impl std::cmp::PartialEq for CmpopLtE { + #[inline] + fn eq(&self, other: &Cmpop) -> bool { + matches!(other, Cmpop::LtE) + } +} + +pub struct CmpopGt; +impl From for Cmpop { + fn from(_: CmpopGt) -> Self { + Cmpop::Gt + } +} +impl Node for CmpopGt { + const NAME: &'static str = "Gt"; + const FIELD_NAMES: &'static [&'static str] = &[]; +} +impl std::cmp::PartialEq for CmpopGt { + #[inline] + fn eq(&self, other: &Cmpop) -> bool { + matches!(other, Cmpop::Gt) + } +} + +pub struct CmpopGtE; +impl From for Cmpop { + fn from(_: CmpopGtE) -> Self { + Cmpop::GtE + } +} +impl Node for CmpopGtE { + const NAME: &'static str = "GtE"; + const FIELD_NAMES: &'static [&'static str] = &[]; +} +impl std::cmp::PartialEq for CmpopGtE { + #[inline] + fn eq(&self, other: &Cmpop) -> bool { + matches!(other, Cmpop::GtE) + } +} + +pub struct CmpopIs; +impl From for Cmpop { + fn from(_: CmpopIs) -> Self { + Cmpop::Is + } +} +impl Node for CmpopIs { + const NAME: &'static str = "Is"; + const FIELD_NAMES: &'static [&'static str] = &[]; +} +impl std::cmp::PartialEq for CmpopIs { + #[inline] + fn eq(&self, other: &Cmpop) -> bool { + matches!(other, Cmpop::Is) + } +} + +pub struct CmpopIsNot; +impl From for Cmpop { + fn from(_: CmpopIsNot) -> Self { + Cmpop::IsNot + } +} +impl Node for CmpopIsNot { + const NAME: &'static str = "IsNot"; + const FIELD_NAMES: &'static [&'static str] = &[]; +} +impl std::cmp::PartialEq for CmpopIsNot { + #[inline] + fn eq(&self, other: &Cmpop) -> bool { + matches!(other, Cmpop::IsNot) + } +} + +pub struct CmpopIn; +impl From for Cmpop { + fn from(_: CmpopIn) -> Self { + Cmpop::In + } +} +impl Node for CmpopIn { + const NAME: &'static str = "In"; + const FIELD_NAMES: &'static [&'static str] = &[]; +} +impl std::cmp::PartialEq for CmpopIn { + #[inline] + fn eq(&self, other: &Cmpop) -> bool { + matches!(other, Cmpop::In) + } +} + +pub struct CmpopNotIn; +impl From for Cmpop { + fn from(_: CmpopNotIn) -> Self { + Cmpop::NotIn + } +} +impl Node for CmpopNotIn { + const NAME: &'static str = "NotIn"; + const FIELD_NAMES: &'static [&'static str] = &[]; +} +impl std::cmp::PartialEq for CmpopNotIn { + #[inline] + fn eq(&self, other: &Cmpop) -> bool { + matches!(other, Cmpop::NotIn) + } +} impl Node for Cmpop { const NAME: &'static str = "cmpop"; diff --git a/ast/src/gen/located.rs b/ast/src/gen/located.rs index 32f01e7..e5f278f 100644 --- a/ast/src/gen/located.rs +++ b/ast/src/gen/located.rs @@ -560,14 +560,78 @@ impl Located for Expr { pub type ExprContext = crate::generic::ExprContext; +pub type ExprContextLoad = crate::generic::ExprContextLoad; + +pub type ExprContextStore = crate::generic::ExprContextStore; + +pub type ExprContextDel = crate::generic::ExprContextDel; + pub type Boolop = crate::generic::Boolop; +pub type BoolopAnd = crate::generic::BoolopAnd; + +pub type BoolopOr = crate::generic::BoolopOr; + pub type Operator = crate::generic::Operator; +pub type OperatorAdd = crate::generic::OperatorAdd; + +pub type OperatorSub = crate::generic::OperatorSub; + +pub type OperatorMult = crate::generic::OperatorMult; + +pub type OperatorMatMult = crate::generic::OperatorMatMult; + +pub type OperatorDiv = crate::generic::OperatorDiv; + +pub type OperatorMod = crate::generic::OperatorMod; + +pub type OperatorPow = crate::generic::OperatorPow; + +pub type OperatorLShift = crate::generic::OperatorLShift; + +pub type OperatorRShift = crate::generic::OperatorRShift; + +pub type OperatorBitOr = crate::generic::OperatorBitOr; + +pub type OperatorBitXor = crate::generic::OperatorBitXor; + +pub type OperatorBitAnd = crate::generic::OperatorBitAnd; + +pub type OperatorFloorDiv = crate::generic::OperatorFloorDiv; + pub type Unaryop = crate::generic::Unaryop; +pub type UnaryopInvert = crate::generic::UnaryopInvert; + +pub type UnaryopNot = crate::generic::UnaryopNot; + +pub type UnaryopUAdd = crate::generic::UnaryopUAdd; + +pub type UnaryopUSub = crate::generic::UnaryopUSub; + pub type Cmpop = crate::generic::Cmpop; +pub type CmpopEq = crate::generic::CmpopEq; + +pub type CmpopNotEq = crate::generic::CmpopNotEq; + +pub type CmpopLt = crate::generic::CmpopLt; + +pub type CmpopLtE = crate::generic::CmpopLtE; + +pub type CmpopGt = crate::generic::CmpopGt; + +pub type CmpopGtE = crate::generic::CmpopGtE; + +pub type CmpopIs = crate::generic::CmpopIs; + +pub type CmpopIsNot = crate::generic::CmpopIsNot; + +pub type CmpopIn = crate::generic::CmpopIn; + +pub type CmpopNotIn = crate::generic::CmpopNotIn; + pub type Comprehension = crate::generic::Comprehension; #[cfg(feature = "all-nodes-with-ranges")] diff --git a/scripts/update_asdl.sh b/scripts/update_asdl.sh index 72bcf32..f2e3db2 100755 --- a/scripts/update_asdl.sh +++ b/scripts/update_asdl.sh @@ -3,5 +3,6 @@ set -e cd "$(dirname "$(dirname "$0")")" +# rm ast/src/gen/*.rs python ast/asdl_rs.py --ast-dir ast/src/gen/ --module-file ../RustPython/vm/src/stdlib/ast/gen.rs ast/Python.asdl rustfmt ast/src/gen/*.rs ../RustPython/vm/src/stdlib/ast/gen.rs