From 771be313a9f553fb88fc578e77be965e8f0a4c52 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Wed, 16 Nov 2022 18:59:38 +0900 Subject: [PATCH] Fix `match` codegen --- compiler/erg_compiler/codegen.rs | 192 +++++++++++++++------- compiler/erg_compiler/context/register.rs | 10 +- compiler/erg_compiler/lower.rs | 12 +- compiler/erg_compiler/optimize.rs | 30 +++- compiler/erg_parser/ast.rs | 2 + compiler/erg_parser/desugar.rs | 6 +- tests/control.er | 1 + 7 files changed, 173 insertions(+), 80 deletions(-) diff --git a/compiler/erg_compiler/codegen.rs b/compiler/erg_compiler/codegen.rs index 887f33ea..90972965 100644 --- a/compiler/erg_compiler/codegen.rs +++ b/compiler/erg_compiler/codegen.rs @@ -29,6 +29,7 @@ use erg_parser::ast::DefId; use erg_parser::ast::DefKind; use CommonOpcode::*; +use erg_parser::ast::PreDeclTypeSpec; use erg_parser::ast::TypeSpec; use erg_parser::ast::{NonDefaultParamSignature, ParamPattern, VarName}; use erg_parser::token::DOT; @@ -222,7 +223,7 @@ impl PyCodeGenerator { self.stack_dec(); } - fn emit_compare_op(&mut self, op: CompareOp) { + fn _emit_compare_op(&mut self, op: CompareOp) { self.write_instr(Opcode311::COMPARE_OP); self.write_arg(op as usize); self.stack_dec(); @@ -231,6 +232,57 @@ impl PyCodeGenerator { } } + /// shut down the interpreter + #[allow(dead_code)] + fn terminate(&mut self) { + self.emit_push_null(); + self.emit_load_name_instr(Identifier::public("exit")); + self.emit_load_const(1); + self.emit_precall_and_call(1); + self.stack_dec(); + } + + /// swap TOS and TOS1 + fn rot2(&mut self) { + if self.py_version.minor >= Some(11) { + self.write_instr(Opcode311::SWAP); + self.write_arg(2); + } else { + self.write_instr(Opcode310::ROT_TWO); + self.write_arg(0); + } + } + + fn dup_top(&mut self) { + if self.py_version.minor >= Some(11) { + self.write_instr(Opcode311::COPY); + self.write_arg(1); + } else { + self.write_instr(Opcode310::DUP_TOP); + self.write_arg(0); + } + self.stack_inc(); + } + + /// COPY(1) == 
DUP_TOP + fn copy(&mut self, i: usize) { + debug_power_assert!(i, >, 0); + if self.py_version.minor >= Some(11) { + self.write_instr(Opcode311::COPY); + self.write_arg(i); + } else { + todo!() + } + self.stack_inc(); + } + + /// 0 origin + #[allow(dead_code)] + fn peek_stack(&mut self, i: usize) { + self.copy(i + 1); + self.emit_print_expr(); + } + #[inline] fn jump_delta(&self, jump_to: usize) -> usize { if self.py_version.minor >= Some(10) { @@ -1252,19 +1304,7 @@ impl PyCodeGenerator { return; } if !self.in_op_loaded { - let mod_name = if self.py_version.minor >= Some(10) { - Identifier::public("_erg_std_prelude") - } else { - Identifier::public("_erg_std_prelude_old") - }; - self.emit_global_import_items( - mod_name, - vec![( - Identifier::public("in_operator"), - Some(Identifier::private("#in_operator")), - )], - ); - self.in_op_loaded = true; + self.load_in_op(); } self.emit_push_null(); self.emit_load_name_instr(Identifier::private("#in_operator")); @@ -1614,18 +1654,10 @@ impl PyCodeGenerator { let expr = args.remove(0); self.emit_expr(expr); let len = args.len(); - let mut absolute_jump_points = vec![]; + let mut jump_forward_points = vec![]; while let Some(expr) = args.try_remove(0) { - // パターンが複数ある場合引数を複製する、ただし最後はしない if len > 1 && !args.is_empty() { - if self.py_version.minor >= Some(11) { - self.write_instr(Opcode311::COPY); - self.write_arg(1); - } else { - self.write_instr(Opcode310::DUP_TOP); - self.write_arg(0); - } - self.stack_inc(); + self.dup_top(); } // compilerで型チェック済み(可読性が下がるため、matchでNamedは使えない) let mut lambda = enum_unwrap!(expr, Expr::Lambda); @@ -1634,8 +1666,11 @@ impl PyCodeGenerator { todo!("default values in match expression are not supported yet") } let param = lambda.params.non_defaults.remove(0); - let pop_jump_points = self.emit_match_pattern(param); + let pop_jump_points = self.emit_match_pattern(param, args.is_empty()); self.emit_frameless_block(lambda.body, Vec::new()); + // If we move on to the next arm, the stack size will 
increase + // so `self.stack_dec();` for now (+1 at the end). + self.stack_dec(); for pop_jump_point in pop_jump_points.into_iter() { let idx = if self.py_version.minor >= Some(11) { self.lasti() - pop_jump_point // - 2 @@ -1643,61 +1678,41 @@ impl PyCodeGenerator { self.lasti() + 2 }; self.calc_edit_jump(pop_jump_point + 1, idx); // jump to POP_TOP - absolute_jump_points.push(self.lasti()); + jump_forward_points.push(self.lasti()); self.write_instr(JUMP_FORWARD); // jump to the end self.write_arg(0); } } let lasti = self.lasti(); - for absolute_jump_point in absolute_jump_points.into_iter() { - self.calc_edit_jump(absolute_jump_point + 1, lasti - absolute_jump_point - 1); + for jump_point in jump_forward_points.into_iter() { + self.calc_edit_jump(jump_point + 1, lasti - jump_point - 1); } + self.stack_inc(); debug_assert_eq!(self.stack_len(), init_stack_len + 1); } - fn emit_match_pattern(&mut self, param: NonDefaultParamSignature) -> Vec<usize> { + fn emit_match_pattern( + &mut self, + param: NonDefaultParamSignature, + is_last_arm: bool, + ) -> Vec<usize> { log!(info "entered {}", fn_name!()); let mut pop_jump_points = vec![]; + if let Some(t_spec) = param.t_spec.map(|spec| spec.t_spec) { + // If it's the last arm, there's no need to inspect it + if !is_last_arm { + self.emit_match_guard(t_spec, &mut pop_jump_points); + } + } match param.pat { ParamPattern::VarName(name) => { let ident = Identifier::bare(None, name); self.emit_store_instr(ident, AccessKind::Name); } - ParamPattern::Array(arr) => { - let len = arr.len(); - self.write_instr(Opcode310::MATCH_SEQUENCE); - self.write_arg(0); - pop_jump_points.push(self.lasti()); - self.write_instr(Opcode310::POP_JUMP_IF_FALSE); - self.write_arg(0); - self.stack_dec(); - self.write_instr(Opcode310::GET_LEN); - self.write_arg(0); - self.emit_load_const(len); - self.emit_compare_op(CompareOp::EQ); - pop_jump_points.push(self.lasti()); - self.write_instr(Opcode310::POP_JUMP_IF_FALSE); - self.write_arg(0); - self.stack_dec(); - 
self.write_instr(Opcode310::UNPACK_SEQUENCE); - self.write_arg(len); - self.stack_inc_n(len - 1); - for elem in arr.elems.non_defaults { - pop_jump_points.append(&mut self.emit_match_pattern(elem)); - } - if !arr.elems.defaults.is_empty() { - todo!("default values in match are not supported yet") - } - } ParamPattern::Discard(_) => { - if let Some(t_spec) = param.t_spec.map(|spec| spec.t_spec) { - self.emit_match_guard(t_spec, &mut pop_jump_points) - } self.emit_pop_top(); } - _other => { - todo!() - } + _other => unreachable!(), } pop_jump_points } @@ -1705,8 +1720,8 @@ impl PyCodeGenerator { fn emit_match_guard(&mut self, t_spec: TypeSpec, pop_jump_points: &mut Vec<usize>) { #[allow(clippy::single_match)] match t_spec { - TypeSpec::Enum(enm) => { - let elems = enm + TypeSpec::Enum(enum_t) => { + let elems = enum_t .deconstruct() .0 .into_iter() @@ -1725,6 +1740,41 @@ impl PyCodeGenerator { // but the numbers are the same, only the way the jumping points are calculated is different. self.write_instr(Opcode310::POP_JUMP_IF_FALSE); // jump to the next case self.write_arg(0); + // self.stack_dec(); + } + TypeSpec::PreDeclTy(PreDeclTypeSpec::Simple(simple)) if simple.args.is_empty() => { + // arg null + // ↓ SWAP 1 + // null arg + // ↓ LOAD_NAME(in_operator) + // null arg in_operator + // ↓ SWAP 1 + // null in_operator arg + // ↓ LOAD_NAME(typ) + // null in_operator arg typ + self.emit_push_null(); + self.rot2(); + if !self.in_op_loaded { + self.load_in_op(); + } + self.emit_load_name_instr(Identifier::private("#in_operator")); + self.rot2(); + // TODO: DOT/not + let mut typ = Identifier::bare(Some(DOT), simple.ident.name); + // TODO: + typ.vi.py_name = match &typ.name.inspect()[..] 
{ + "Int" => Some("int".into()), + "Float" => Some("float".into()), + _ => None, + }; + self.emit_load_name_instr(typ); + self.emit_precall_and_call(2); + self.stack_dec(); + pop_jump_points.push(self.lasti()); + // in 3.11, POP_JUMP_IF_FALSE is replaced with POP_JUMP_FORWARD_IF_FALSE + // but the numbers are the same, only the way the jumping points are calculated is different. + self.write_instr(Opcode310::POP_JUMP_IF_FALSE); // jump to the next case + self.write_arg(0); self.stack_dec(); } /*TypeSpec::Interval { op, lhs, rhs } => { @@ -2610,6 +2660,22 @@ impl PyCodeGenerator { self.record_type_loaded = true; } + fn load_in_op(&mut self) { + let mod_name = if self.py_version.minor >= Some(10) { + Identifier::public("_erg_std_prelude") + } else { + Identifier::public("_erg_std_prelude_old") + }; + self.emit_global_import_items( + mod_name, + vec![( + Identifier::public("in_operator"), + Some(Identifier::private("#in_operator")), + )], + ); + self.in_op_loaded = true; + } + fn load_prelude_py(&mut self) { self.emit_global_import_items( Identifier::public("sys"), diff --git a/compiler/erg_compiler/context/register.rs b/compiler/erg_compiler/context/register.rs index 09643a6c..3560240e 100644 --- a/compiler/erg_compiler/context/register.rs +++ b/compiler/erg_compiler/context/register.rs @@ -162,6 +162,9 @@ impl Context { ) -> TyCheckResult<()> { let ident = match &sig.pat { ast::VarPattern::Ident(ident) => ident, + ast::VarPattern::Discard(_) => { + return Ok(()); + } _ => todo!(), }; // already defined as const @@ -545,7 +548,8 @@ impl Context { pub(crate) fn preregister_def(&mut self, def: &ast::Def) -> TyCheckResult<()> { let id = Some(def.body.id); - let __name__ = def.sig.ident().unwrap().inspect(); + let ubar = Str::ever("_"); + let __name__ = def.sig.ident().map(|i| i.inspect()).unwrap_or(&ubar); match &def.sig { ast::Signature::Subr(sig) => { if sig.is_const() { @@ -597,7 +601,9 @@ impl Context { self.sub_unify(&const_t, &spec_t, def.body.loc(), None)?; } 
self.pop(); - self.register_gen_const(sig.ident().unwrap(), obj)?; + if let Some(ident) = sig.ident() { + self.register_gen_const(ident, obj)?; + } } _ => {} } diff --git a/compiler/erg_compiler/lower.rs b/compiler/erg_compiler/lower.rs index b75efbf2..ac8b282c 100644 --- a/compiler/erg_compiler/lower.rs +++ b/compiler/erg_compiler/lower.rs @@ -894,15 +894,13 @@ impl ASTLowerer { match self.lower_block(body.block) { Ok(block) => { let found_body_t = block.ref_t(); - let opt_expect_body_t = self - .ctx - .outer - .as_ref() - .unwrap() - .get_current_scope_var(sig.inspect().unwrap()) - .map(|vi| vi.t.clone()); + let outer = self.ctx.outer.as_ref().unwrap(); + let opt_expect_body_t = sig + .inspect() + .and_then(|name| outer.get_current_scope_var(name).map(|vi| vi.t.clone())); let ident = match &sig.pat { ast::VarPattern::Ident(ident) => ident, + ast::VarPattern::Discard(_) => ast::Identifier::UBAR, _ => unreachable!(), }; if let Some(expect_body_t) = opt_expect_body_t { diff --git a/compiler/erg_compiler/optimize.rs b/compiler/erg_compiler/optimize.rs index 32619fc2..8ed6fa52 100644 --- a/compiler/erg_compiler/optimize.rs +++ b/compiler/erg_compiler/optimize.rs @@ -1,19 +1,41 @@ +use crate::artifact::CompleteArtifact; use crate::error::CompileWarnings; -use crate::hir::HIR; +use crate::hir::*; +// use crate::erg_common::traits::Stream; #[derive(Debug)] pub struct HIROptimizer {} impl HIROptimizer { - pub fn fold_constants(&mut self, mut _hir: HIR) -> HIR { + pub fn optimize(hir: HIR) -> CompleteArtifact { + let mut optimizer = HIROptimizer {}; + optimizer.eliminate_dead_code(hir) + } + + fn _fold_constants(&mut self, mut _hir: HIR) -> HIR { todo!() } - pub fn eliminate_unused_variables(&mut self, mut _hir: HIR) -> (HIR, CompileWarnings) { + fn _eliminate_unused_variables(&mut self, mut _hir: HIR) -> (HIR, CompileWarnings) { todo!() } - pub fn eliminate_dead_code(&mut self, mut _hir: HIR) -> (HIR, CompileWarnings) { + fn eliminate_dead_code(&mut self, hir: HIR) -> 
CompleteArtifact { + CompleteArtifact::new( + self.eliminate_discarded_variables(hir), + CompileWarnings::empty(), + ) + } + + /// ```erg + /// _ = 1 + /// (a, _) = (1, True) + /// ``` + /// ↓ + /// ```erg + /// a = 1 + /// ``` + fn eliminate_discarded_variables(&mut self, mut _hir: HIR) -> HIR { todo!() } } diff --git a/compiler/erg_parser/ast.rs b/compiler/erg_parser/ast.rs index e8b43328..2074d1d5 100644 --- a/compiler/erg_parser/ast.rs +++ b/compiler/erg_parser/ast.rs @@ -2241,6 +2241,8 @@ impl From<&Identifier> for Field { } impl Identifier { + pub const UBAR: &Self = &Self::new(None, VarName::from_static("_")); + pub const fn new(dot: Option<Token>, name: VarName) -> Self { Self { dot, name } } diff --git a/compiler/erg_parser/desugar.rs b/compiler/erg_parser/desugar.rs index c08ea881..faca1686 100644 --- a/compiler/erg_parser/desugar.rs +++ b/compiler/erg_parser/desugar.rs @@ -466,7 +466,7 @@ impl Desugarer { ); } } - VarPattern::Ident(_i) => { + VarPattern::Ident(_) | VarPattern::Discard(_) => { let block = body .block .into_iter() @@ -476,7 +476,6 @@ impl Desugarer { let def = Def::new(Signature::Var(v), body); new.push(Expr::Def(def)); } - _ => {} }, Expr::Def(Def { sig: Signature::Subr(mut subr), @@ -582,11 +581,10 @@ impl Desugarer { ); } } - VarPattern::Ident(_ident) => { + VarPattern::Ident(_) | VarPattern::Discard(_) => { let def = Def::new(Signature::Var(sig.clone()), body); new_module.push(Expr::Def(def)); } - _ => {} } } diff --git a/tests/control.er b/tests/control.er index c061ab84..2c2779c5 100644 --- a/tests/control.er +++ b/tests/control.er @@ -13,6 +13,7 @@ if! cond: a = 1 _ = match a: (i: Int) -> i + (s: Int) -> 1 _ -> panic "unknown object" for! 0..<10, i =>