From bd29985cc75f474a16033a93afb3f3dbd066435b Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Wed, 9 Nov 2022 19:32:46 +0900 Subject: [PATCH] Implement match guard (literal) --- compiler/erg_compiler/codegen.rs | 59 +++++++++++++++-------- compiler/erg_parser/ast.rs | 4 ++ compiler/erg_parser/desugar.rs | 80 +++++++++++++++++++++++++------- 3 files changed, 106 insertions(+), 37 deletions(-) diff --git a/compiler/erg_compiler/codegen.rs b/compiler/erg_compiler/codegen.rs index 10cd5723..b252300a 100644 --- a/compiler/erg_compiler/codegen.rs +++ b/compiler/erg_compiler/codegen.rs @@ -24,10 +24,12 @@ use erg_common::{ debug_power_assert, enum_unwrap, fn_name, fn_name_full, impl_stream_for_wrapper, log, switch_unreachable, }; +use erg_parser::ast::ConstExpr; use erg_parser::ast::DefId; use erg_parser::ast::DefKind; use CommonOpcode::*; +use erg_parser::ast::TypeSpec; use erg_parser::ast::{NonDefaultParamSignature, ParamPattern, VarName}; use erg_parser::token::DOT; use erg_parser::token::EQUAL; @@ -1603,8 +1605,8 @@ impl PyCodeGenerator { if !lambda.params.defaults.is_empty() { todo!("default values in match expression are not supported yet") } - let pat = lambda.params.non_defaults.remove(0).pat; - let pop_jump_points = self.emit_match_pattern(pat); + let param = lambda.params.non_defaults.remove(0); + let pop_jump_points = self.emit_match_pattern(param); self.emit_frameless_block(lambda.body, Vec::new()); for pop_jump_point in pop_jump_points.into_iter() { let idx = if self.py_version.minor >= Some(11) { @@ -1624,29 +1626,14 @@ impl PyCodeGenerator { } } - fn emit_match_pattern(&mut self, pat: ParamPattern) -> Vec { + fn emit_match_pattern(&mut self, param: NonDefaultParamSignature) -> Vec { log!(info "entered {}", fn_name!()); let mut pop_jump_points = vec![]; - match pat { + match param.pat { ParamPattern::VarName(name) => { let ident = Identifier::bare(None, name); self.emit_store_instr(ident, AccessKind::Name); } - ParamPattern::Lit(lit) => { - let value = { - let t = type_from_token_kind(lit.token.kind); - ValueObj::from_str(t, lit.token.content).unwrap() - }; - self.emit_load_const(value); - self.emit_compare_op(CompareOp::EQ); - pop_jump_points.push(self.lasti()); - // in 3.11, POP_JUMP_IF_FALSE is replaced with POP_JUMP_FORWARD_IF_FALSE - // but the numbers are the same, only the way the jumping points are calculated is different. - self.write_instr(Opcode310::POP_JUMP_IF_FALSE); // jump to the next case - self.write_arg(0); - self.emit_pop_top(); - self.stack_dec(); - } ParamPattern::Array(arr) => { let len = arr.len(); self.write_instr(Opcode310::MATCH_SEQUENCE); @@ -1667,13 +1654,16 @@ impl PyCodeGenerator { self.write_arg(len); self.stack_inc_n(len - 1); for elem in arr.elems.non_defaults { - pop_jump_points.append(&mut self.emit_match_pattern(elem.pat)); + pop_jump_points.append(&mut self.emit_match_pattern(elem)); } if !arr.elems.defaults.is_empty() { todo!("default values in match are not supported yet") } } ParamPattern::Discard(_) => { + if let Some(t_spec) = param.t_spec.map(|spec| spec.t_spec) { + self.emit_match_guard(t_spec, &mut pop_jump_points) + } self.emit_pop_top(); } _other => { @@ -1683,6 +1673,35 @@ impl PyCodeGenerator { pop_jump_points } + fn emit_match_guard(&mut self, t_spec: TypeSpec, pop_jump_points: &mut Vec) { + #[allow(clippy::single_match)] + match t_spec { + TypeSpec::Enum(enm) => { + let (mut elems, ..) = enm.deconstruct(); + if elems.len() != 1 { + todo!() + } + let ConstExpr::Lit(lit) = elems.remove(0).expr else { + todo!() + }; + let value = { + let t = type_from_token_kind(lit.token.kind); + ValueObj::from_str(t, lit.token.content).unwrap() + }; + self.emit_load_const(value); + self.emit_compare_op(CompareOp::EQ); + pop_jump_points.push(self.lasti()); + // in 3.11, POP_JUMP_IF_FALSE is replaced with POP_JUMP_FORWARD_IF_FALSE + // but the numbers are the same, only the way the jumping points are calculated is different. + self.write_instr(Opcode310::POP_JUMP_IF_FALSE); // jump to the next case + self.write_arg(0); + self.stack_dec(); + } + // TODO: + _ => {} + } + } + fn emit_with_instr_311(&mut self, args: Args) { log!(info "entered {}", fn_name!()); let mut args = args; diff --git a/compiler/erg_parser/ast.rs b/compiler/erg_parser/ast.rs index ec879cf4..6d2a87e1 100644 --- a/compiler/erg_parser/ast.rs +++ b/compiler/erg_parser/ast.rs @@ -1529,6 +1529,10 @@ impl ConstArgs { } } + pub fn deconstruct(self) -> (Vec, Vec, Option<(Token, Token)>) { + (self.pos_args, self.kw_args, self.paren) + } + pub const fn empty() -> Self { Self::new(vec![], vec![], None) } diff --git a/compiler/erg_parser/desugar.rs b/compiler/erg_parser/desugar.rs index 786011ba..c08ea881 100644 --- a/compiler/erg_parser/desugar.rs +++ b/compiler/erg_parser/desugar.rs @@ -63,7 +63,7 @@ impl Desugarer { module.into_iter().map(desugar).collect() } - fn perform_desugar(desugar: impl Fn(Expr) -> Expr, expr: Expr) -> Expr { + fn perform_desugar(mut desugar: impl FnMut(Expr) -> Expr, expr: Expr) -> Expr { match expr { Expr::Record(record) => match record { Record::Normal(rec) => { @@ -77,7 +77,7 @@ impl Desugarer { RecordAttrs::new(new_attrs), ))) } - _ => todo!(), + shorten => Expr::Record(shorten), }, Expr::DataPack(pack) => { let class = desugar(*pack.class); @@ -361,6 +361,20 @@ impl Desugarer { (buf_name, pat) } + fn rec_desugar_lambda_pattern(&mut self, expr: Expr) -> Expr { + match expr { + Expr::Lambda(mut lambda) => { + let non_defaults = lambda.sig.params.non_defaults.iter_mut(); + for param in non_defaults { + self.desugar_nd_param(param, &mut lambda.body); + } + Expr::Lambda(lambda) + } + expr => Self::perform_desugar(|ex| self.rec_desugar_lambda_pattern(ex), expr), + } + } + + // TODO: nested function pattern /// `[i, j] = [1, 2]` -> `i = 1; j = 2` /// `[i, j] = l` -> `i = l[0]; j = l[1]` /// `[i, [j, k]] = l` -> `i = l[0]; j = l[1][0]; k = l[1][1]` @@ -377,7 +391,12 @@ impl Desugarer { VarPattern::Tuple(tup) => { let (buf_name, buf_sig) = self.gen_buf_name_and_sig(v.ln_begin().unwrap(), v.t_spec); - let buf_def = Def::new(buf_sig, body); + let block = body + .block + .into_iter() + .map(|ex| self.rec_desugar_lambda_pattern(ex)) + .collect(); + let buf_def = Def::new(buf_sig, DefBody::new(body.op, block, body.id)); new.push(Expr::Def(buf_def)); for (n, elem) in tup.elems.iter().enumerate() { self.desugar_nested_var_pattern( @@ -391,7 +410,12 @@ impl Desugarer { VarPattern::Array(arr) => { let (buf_name, buf_sig) = self.gen_buf_name_and_sig(v.ln_begin().unwrap(), v.t_spec); - let buf_def = Def::new(buf_sig, body); + let block = body + .block + .into_iter() + .map(|ex| self.rec_desugar_lambda_pattern(ex)) + .collect(); + let buf_def = Def::new(buf_sig, DefBody::new(body.op, block, body.id)); new.push(Expr::Def(buf_def)); for (n, elem) in arr.elems.iter().enumerate() { self.desugar_nested_var_pattern( @@ -405,7 +429,12 @@ impl Desugarer { VarPattern::Record(rec) => { let (buf_name, buf_sig) = self.gen_buf_name_and_sig(v.ln_begin().unwrap(), v.t_spec); - let buf_def = Def::new(buf_sig, body); + let block = body + .block + .into_iter() + .map(|ex| self.rec_desugar_lambda_pattern(ex)) + .collect(); + let buf_def = Def::new(buf_sig, DefBody::new(body.op, block, body.id)); new.push(Expr::Def(buf_def)); for VarRecordAttr { lhs, rhs } in rec.attrs.iter() { self.desugar_nested_var_pattern( @@ -421,7 +450,12 @@ impl Desugarer { v.ln_begin().unwrap(), Some(pack.class.clone()), // TODO: これだとvの型指定の意味がなくなる ); - let buf_def = Def::new(buf_sig, body); + let block = body + .block + .into_iter() + .map(|ex| self.rec_desugar_lambda_pattern(ex)) + .collect(); + let buf_def = Def::new(buf_sig, DefBody::new(body.op, block, body.id)); new.push(Expr::Def(buf_def)); for VarRecordAttr { lhs, rhs } in pack.args.attrs.iter() { self.desugar_nested_var_pattern( @@ -433,6 +467,12 @@ impl Desugarer { } } VarPattern::Ident(_i) => { + let block = body + .block + .into_iter() + .map(|ex| self.rec_desugar_lambda_pattern(ex)) + .collect(); + let body = DefBody::new(body.op, block, body.id); let def = Def::new(Signature::Var(v), body); new.push(Expr::Def(def)); } @@ -444,13 +484,19 @@ impl Desugarer { }) => { let non_defaults = subr.params.non_defaults.iter_mut(); for param in non_defaults { - self.desugar_nd_param(param, &mut body); + self.desugar_nd_param(param, &mut body.block); } + let block = body + .block + .into_iter() + .map(|ex| self.rec_desugar_lambda_pattern(ex)) + .collect(); + let body = DefBody::new(body.op, block, body.id); let def = Def::new(Signature::Subr(subr), body); new.push(Expr::Def(def)); } other => { - new.push(other); + new.push(self.rec_desugar_lambda_pattern(other)); } } } @@ -622,7 +668,7 @@ impl Desugarer { /// ```erg /// f _: {1}, _: {2} = ... /// ``` - fn desugar_nd_param(&mut self, param: &mut NonDefaultParamSignature, body: &mut DefBody) { + fn desugar_nd_param(&mut self, param: &mut NonDefaultParamSignature, body: &mut Block) { let mut insertion_idx = 0; let line = param.ln_begin().unwrap(); match &mut param.pat { @@ -735,7 +781,7 @@ impl Desugarer { fn desugar_nested_param_pattern( &mut self, - new_sub_body: &mut DefBody, + new_body: &mut Block, sig: &mut NonDefaultParamSignature, buf_name: &str, buf_index: BufIndex, @@ -763,7 +809,7 @@ impl Desugarer { match &mut sig.pat { ParamPattern::Tuple(tup) => { let (buf_name, buf_sig) = self.gen_buf_nd_param(line); - new_sub_body.block.insert( + new_body.insert( insertion_idx, Expr::Def(Def::new( Signature::Var(VarSignature::new( @@ -777,7 +823,7 @@ impl Desugarer { let mut tys = vec![]; for (n, elem) in tup.elems.non_defaults.iter_mut().enumerate() { insertion_idx = self.desugar_nested_param_pattern( - new_sub_body, + new_body, elem, &buf_name, BufIndex::Tuple(n), @@ -800,7 +846,7 @@ impl Desugarer { } ParamPattern::Array(arr) => { let (buf_name, buf_sig) = self.gen_buf_nd_param(line); - new_sub_body.block.insert( + new_body.insert( insertion_idx, Expr::Def(Def::new( Signature::Var(VarSignature::new( @@ -813,7 +859,7 @@ impl Desugarer { insertion_idx += 1; for (n, elem) in arr.elems.non_defaults.iter_mut().enumerate() { insertion_idx = self.desugar_nested_param_pattern( - new_sub_body, + new_body, elem, &buf_name, BufIndex::Array(n), @@ -832,7 +878,7 @@ impl Desugarer { } ParamPattern::Record(rec) => { let (buf_name, buf_sig) = self.gen_buf_nd_param(line); - new_sub_body.block.insert( + new_body.insert( insertion_idx, Expr::Def(Def::new( Signature::Var(VarSignature::new( @@ -846,7 +892,7 @@ impl Desugarer { let mut tys = vec![]; for ParamRecordAttr { lhs, rhs } in rec.elems.iter_mut() { insertion_idx = self.desugar_nested_param_pattern( - new_sub_body, + new_body, rhs, &buf_name, BufIndex::Record(lhs), @@ -890,7 +936,7 @@ impl Desugarer { sig.t_spec.as_ref().map(|ts| ts.t_spec.clone()), ); let def = Def::new(Signature::Var(v), body); - new_sub_body.block.insert(insertion_idx, Expr::Def(def)); + new_body.insert(insertion_idx, Expr::Def(def)); insertion_idx += 1; insertion_idx }