From fc85265d9fbbdc1f153acb73ddb9a15ec9df8052 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Mon, 10 Apr 2023 22:26:46 +0900 Subject: [PATCH] fix: union types bug & multi-pattern def bug --- crates/erg_compiler/context/compare.rs | 71 +++++++---- crates/erg_compiler/context/eval.rs | 10 +- crates/erg_compiler/context/generalize.rs | 28 +++++ crates/erg_compiler/context/register.rs | 12 +- crates/erg_compiler/context/unify.rs | 35 ++++-- crates/erg_compiler/error/lower.rs | 1 + .../erg_compiler/lib/std/_erg_in_operator.py | 10 +- crates/erg_compiler/lint.rs | 45 ++++++- crates/erg_compiler/lower.rs | 17 +-- crates/erg_compiler/ty/free.rs | 19 ++- crates/erg_compiler/ty/mod.rs | 16 +++ crates/erg_parser/desugar.rs | 112 +++++++++++++----- tests/should_ok/rec.er | 9 +- tests/test.rs | 2 +- 14 files changed, 299 insertions(+), 88 deletions(-) diff --git a/crates/erg_compiler/context/compare.rs b/crates/erg_compiler/context/compare.rs index bcb5c77c..1d07365e 100644 --- a/crates/erg_compiler/context/compare.rs +++ b/crates/erg_compiler/context/compare.rs @@ -8,7 +8,7 @@ use erg_common::traits::StructuralEq; use erg_common::Str; use erg_common::{assume_unreachable, log}; -use crate::ty::constructors::{and, not, or, poly}; +use crate::ty::constructors::{and, bounded, not, or, poly}; use crate::ty::free::{Constraint, FreeKind}; use crate::ty::typaram::{OpKind, TyParam, TyParamOrdering}; use crate::ty::value::ValueObj; @@ -982,6 +982,18 @@ impl Context { } (Refinement(l), Refinement(r)) => Type::Refinement(self.union_refinement(l, r)), (Structural(l), Structural(r)) => self.union(l, r).structuralize(), + // Int..Obj or Nat..Obj ==> Int..Obj + // Str..Obj or Int..Obj ==> Str..Obj or Int..Obj + ( + Bounded { sub, sup }, + Bounded { + sub: sub2, + sup: sup2, + }, + ) => match (self.max(sub, sub2), self.min(sup, sup2)) { + (Some(sub), Some(sup)) => bounded(sub.clone(), sup.clone()), + _ => self.simple_union(lhs, rhs), + }, (t, Type::Never) | (Type::Never, t) => t.clone(), // Array({1, 2}, 2), Array({3, 4}, 2) ==> Array({1, 2, 3, 4}, 2) ( @@ -997,26 +1009,10 @@ impl Context { debug_assert_eq!(lps.len(), rps.len()); let mut unified_params = vec![]; for (lp, rp) in lps.iter().zip(rps.iter()) { - match (lp, rp) { - (TyParam::Value(ValueObj::Type(l)), TyParam::Value(ValueObj::Type(r))) => { - unified_params.push(TyParam::t(self.union(l.typ(), r.typ()))); - } - (TyParam::Value(ValueObj::Type(l)), TyParam::Type(r)) => { - unified_params.push(TyParam::t(self.union(l.typ(), r))); - } - (TyParam::Type(l), TyParam::Value(ValueObj::Type(r))) => { - unified_params.push(TyParam::t(self.union(l, r.typ()))); - } - (TyParam::Type(l), TyParam::Type(r)) => { - unified_params.push(TyParam::t(self.union(l, r))); - } - (_, _) => { - if self.eq_tp(lp, rp) { - unified_params.push(lp.clone()); - } else { - return self.simple_union(lhs, rhs); - } - } + if let Some(union) = self.union_tp(lp, rp) { + unified_params.push(union); + } else { + return self.simple_union(lhs, rhs); } } poly(ln, unified_params) @@ -1025,6 +1021,39 @@ impl Context { } } + fn union_tp(&self, lhs: &TyParam, rhs: &TyParam) -> Option { + match (lhs, rhs) { + (TyParam::Value(ValueObj::Type(l)), TyParam::Value(ValueObj::Type(r))) => { + Some(TyParam::t(self.union(l.typ(), r.typ()))) + } + (TyParam::Value(ValueObj::Type(l)), TyParam::Type(r)) => { + Some(TyParam::t(self.union(l.typ(), r))) + } + (TyParam::Type(l), TyParam::Value(ValueObj::Type(r))) => { + Some(TyParam::t(self.union(l, r.typ()))) + } + (TyParam::Type(l), TyParam::Type(r)) => Some(TyParam::t(self.union(l, r))), + (TyParam::Array(l), TyParam::Array(r)) => { + let mut tps = vec![]; + for (l, r) in l.iter().zip(r.iter()) { + if let Some(tp) = self.union_tp(l, r) { + tps.push(tp); + } else { + return None; + } + } + Some(TyParam::Array(tps)) + } + (_, _) => { + if self.eq_tp(lhs, rhs) { + Some(lhs.clone()) + } else { + None + } + } + } + } + fn simple_union(&self, lhs: &Type, rhs: &Type) -> Type { // `?T or ?U` will not be unified // `Set!(?T(<: Int), 3) or Set(?U(<: Nat), 3)` wii be unified to Set(?T, 3) diff --git a/crates/erg_compiler/context/eval.rs b/crates/erg_compiler/context/eval.rs index cd99584e..21180c4a 100644 --- a/crates/erg_compiler/context/eval.rs +++ b/crates/erg_compiler/context/eval.rs @@ -21,7 +21,7 @@ use crate::ty::constructors::{ array_t, dict_t, mono, poly, proj, proj_call, ref_, ref_mut, refinement, subr_t, tuple_t, v_enum, }; -use crate::ty::free::{Constraint, HasLevel}; +use crate::ty::free::{Constraint, FreeTyVar, HasLevel}; use crate::ty::typaram::{OpKind, TyParam}; use crate::ty::value::{GenTypeObj, TypeObj, ValueObj}; use crate::ty::{ConstSubr, HasType, Predicate, SubrKind, Type, UserConstSubr, ValueArgs}; @@ -1275,7 +1275,7 @@ impl Context { let t = self .convert_tp_into_type(params[0].clone()) .map_err(|_| ())?; - let len = enum_unwrap!(params[1], TyParam::Value:(ValueObj::Nat:(_))); + let TyParam::Value(ValueObj::Nat(len)) = params[1] else { unreachable!() }; Ok(vec![ValueObj::builtin_type(t); len as usize]) } _ => Err(()), @@ -1432,7 +1432,7 @@ impl Context { Ok(()) } TyParam::Type(gt) if gt.is_generalized() => { - let qt = enum_unwrap!(gt.as_ref(), Type::FreeVar); + let Ok(qt) = <&FreeTyVar>::try_from(gt.as_ref()) else { unreachable!() }; let Ok(st) = Type::try_from(stp) else { todo!(); }; if !st.is_generalized() { qt.undoable_link(&st); @@ -1442,7 +1442,7 @@ impl Context { TyParam::Type(qt) => { let Ok(st) = Type::try_from(stp) else { todo!(); }; let st = if st.typarams_len() != qt.typarams_len() { - let st = enum_unwrap!(st, Type::FreeVar); + let Ok(st) = <&FreeTyVar>::try_from(&st) else { unreachable!() }; st.get_sub().unwrap() } else { st @@ -1461,7 +1461,7 @@ impl Context { match tp { TyParam::FreeVar(fv) if fv.is_undoable_linked() => fv.undo(), TyParam::Type(t) if t.is_free_var() => { - let subst = enum_unwrap!(t.as_ref(), Type::FreeVar); + let Ok(subst) = <&FreeTyVar>::try_from(t.as_ref()) else { unreachable!() }; if subst.is_undoable_linked() { subst.undo(); } diff --git a/crates/erg_compiler/context/generalize.rs b/crates/erg_compiler/context/generalize.rs index 4da75600..4e6e03c9 100644 --- a/crates/erg_compiler/context/generalize.rs +++ b/crates/erg_compiler/context/generalize.rs @@ -1213,4 +1213,32 @@ impl Context { hir::Expr::Import(_) => unreachable!(), } } + + /// ```erg + /// squash_tyvar(?1 or ?2) == ?1(== ?2) + /// squash_tyvar(?T or ?U) == ?T or ?U + /// ``` + pub(crate) fn squash_tyvar(&self, typ: Type) -> Type { + match typ { + Type::Or(l, r) => { + let l = self.squash_tyvar(*l); + let r = self.squash_tyvar(*r); + if l.is_named_unbound_var() && r.is_named_unbound_var() { + self.union(&l, &r) + } else { + match (self.subtype_of(&l, &r), self.subtype_of(&r, &l)) { + (true, true) | (true, false) => { + let _ = self.sub_unify(&l, &r, &(), None); + } + (false, true) => { + let _ = self.sub_unify(&r, &l, &(), None); + } + _ => {} + } + self.union(&l, &r) + } + } + other => other, + } + } } diff --git a/crates/erg_compiler/context/register.rs b/crates/erg_compiler/context/register.rs index 63934b65..b6948f81 100644 --- a/crates/erg_compiler/context/register.rs +++ b/crates/erg_compiler/context/register.rs @@ -656,8 +656,16 @@ impl Context { TyCheckErrors::new( errs.into_iter() .map(|e| { - let expect = self.readable_type(spec_ret_t.clone()); - let found = self.readable_type(body_t.clone()); + let expect = if cfg!(feature = "debug") { + spec_ret_t.clone() + } else { + self.readable_type(spec_ret_t.clone()) + }; + let found = if cfg!(feature = "debug") { + body_t.clone() + } else { + self.readable_type(body_t.clone()) + }; TyCheckError::return_type_error( self.cfg.input.clone(), line!() as usize, diff --git a/crates/erg_compiler/context/unify.rs b/crates/erg_compiler/context/unify.rs index 9c09cadf..0d223af3 100644 --- a/crates/erg_compiler/context/unify.rs +++ b/crates/erg_compiler/context/unify.rs @@ -30,6 +30,8 @@ impl Context { /// occur(X -> ?T, X -> ?T) ==> OK /// occur(?T, ?T -> X) ==> Error /// occur(?T, Option(?T)) ==> Error + /// occur(?T or ?U, ?T) ==> OK + /// occur(?T(<: Str) or ?U(<: Int), ?T(<: Str)) ==> Error /// occur(?T, ?T.Output) ==> OK pub(crate) fn occur( &self, @@ -118,6 +120,10 @@ impl Context { } Ok(()) } + (Or(l, r), Or(l2, r2)) | (And(l, r), And(l2, r2)) => { + self.occur(l, l2, loc)?; + self.occur(r, r2, loc) + } (lhs, Or(l, r)) | (lhs, And(l, r)) => { self.occur_inner(lhs, l, loc)?; self.occur_inner(lhs, r, loc) @@ -787,14 +793,29 @@ impl Context { self.caused_by(), ))); }; - if sub_fv.level().unwrap_or(GENERIC_LEVEL) - <= sup_fv.level().unwrap_or(GENERIC_LEVEL) + match sub_fv + .level() + .unwrap_or(GENERIC_LEVEL) + .cmp(&sup_fv.level().unwrap_or(GENERIC_LEVEL)) { - sub_fv.update_constraint(new_constraint, false); - sup_fv.link(maybe_sub); - } else { - sup_fv.update_constraint(new_constraint, false); - sub_fv.link(maybe_sup); + std::cmp::Ordering::Less => { + sub_fv.update_constraint(new_constraint, false); + sup_fv.link(maybe_sub); + } + std::cmp::Ordering::Greater => { + sup_fv.update_constraint(new_constraint, false); + sub_fv.link(maybe_sup); + } + std::cmp::Ordering::Equal => { + // choose named one + if sup_fv.is_named_unbound() { + sup_fv.update_constraint(new_constraint, false); + sub_fv.link(maybe_sup); + } else { + sub_fv.update_constraint(new_constraint, false); + sup_fv.link(maybe_sub); + } + } } Ok(()) } diff --git a/crates/erg_compiler/error/lower.rs b/crates/erg_compiler/error/lower.rs index fba11da9..94a97423 100644 --- a/crates/erg_compiler/error/lower.rs +++ b/crates/erg_compiler/error/lower.rs @@ -1024,6 +1024,7 @@ impl LowerWarning { fn_name: &str, typ: &Type, ) -> Self { + let fn_name = fn_name.with_color(Color::Yellow); let hint = switch_lang!( "japanese" => format!("`{fn_name}(...): {typ} = ...`など明示的に戻り値型を指定してください"), "simplified_chinese" => format!("请明确指定函数{fn_name}的返回类型,例如`{fn_name}(...): {typ} = ...`"), diff --git a/crates/erg_compiler/lib/std/_erg_in_operator.py b/crates/erg_compiler/lib/std/_erg_in_operator.py index a3ab5fb8..c03dcd46 100644 --- a/crates/erg_compiler/lib/std/_erg_in_operator.py +++ b/crates/erg_compiler/lib/std/_erg_in_operator.py @@ -10,14 +10,18 @@ def in_operator(elem, y): return True # TODO: trait check return False - elif issubclass(type(y), list) and ( - type(y[0]) == type or issubclass(type(y[0]), Range) + elif isinstance(y, list) and ( + type(y[0]) == type or isinstance(y[0], Range) ): # FIXME: type_check = in_operator(elem[0], y[0]) len_check = len(elem) == len(y) return type_check and len_check - elif issubclass(type(y), dict) and issubclass(type(next(iter(y.keys()))), type): + elif isinstance(y, tuple): + type_check = all(map(lambda x: in_operator(x[0], x[1]), zip(elem, y))) + len_check = len(elem) == len(y) + return type_check and len_check + elif isinstance(y, dict) and isinstance(next(iter(y.keys())), type): # TODO: type_check = True # in_operator(x[next(iter(x.keys()))], next(iter(y.keys()))) len_check = len(elem) >= len(y) diff --git a/crates/erg_compiler/lint.rs b/crates/erg_compiler/lint.rs index 67e0bb34..06488191 100644 --- a/crates/erg_compiler/lint.rs +++ b/crates/erg_compiler/lint.rs @@ -16,7 +16,7 @@ use crate::ty::{HasType, Type, ValueObj, VisibilityModifier}; use crate::error::{ CompileErrors, LowerError, LowerResult, LowerWarning, LowerWarnings, SingleLowerResult, }; -use crate::hir::{self, Expr, HIR}; +use crate::hir::{self, Expr, Signature, HIR}; use crate::lower::ASTLowerer; use crate::varinfo::VarInfo; @@ -279,4 +279,47 @@ impl ASTLowerer { self.check_doc_comments(&hir); self.module.context.pop(); } + + pub(crate) fn warn_implicit_union(&mut self, hir: &HIR) { + for chunk in hir.module.iter() { + self.warn_implicit_union_chunk(chunk); + } + } + + fn warn_implicit_union_chunk(&mut self, chunk: &Expr) { + match chunk { + Expr::ClassDef(class_def) => { + for chunk in class_def.methods.iter() { + self.warn_implicit_union_chunk(chunk); + } + } + Expr::PatchDef(patch_def) => { + for chunk in patch_def.methods.iter() { + self.warn_implicit_union_chunk(chunk); + } + } + Expr::Def(def) => { + if let Signature::Subr(subr) = &def.sig { + let return_t = subr.ref_t().return_t().unwrap(); + if return_t.union_pair().is_some() && subr.return_t_spec.is_none() { + let typ = if cfg!(feature = "debug") { + return_t.clone() + } else { + self.module.context.readable_type(return_t.clone()) + }; + let warn = LowerWarning::union_return_type_warning( + self.input().clone(), + line!() as usize, + subr.loc(), + self.module.context.caused_by(), + subr.ident.inspect(), + &typ, + ); + self.warns.push(warn); + } + } + } + _ => {} + } + } } diff --git a/crates/erg_compiler/lower.rs b/crates/erg_compiler/lower.rs index c9889053..d77e5022 100644 --- a/crates/erg_compiler/lower.rs +++ b/crates/erg_compiler/lower.rs @@ -1582,25 +1582,13 @@ impl ASTLowerer { } match self.lower_block(body.block) { Ok(block) => { - let found_body_t = block.ref_t(); + let found_body_t = self.module.context.squash_tyvar(block.t()); let vi = self.module.context.outer.as_mut().unwrap().assign_subr( &sig, body.id, - found_body_t, + &found_body_t, block.last().unwrap(), )?; - let return_t = vi.t.return_t().unwrap(); - if return_t.union_pair().is_some() && sig.return_t_spec.is_none() { - let warn = LowerWarning::union_return_type_warning( - self.input().clone(), - line!() as usize, - sig.loc(), - self.module.context.caused_by(), - sig.ident.inspect(), - &self.module.context.readable_type(return_t.clone()), - ); - self.warns.push(warn); - } let ident = hir::Identifier::new(sig.ident, None, vi); let sig = hir::SubrSignature::new(ident, sig.bounds, params, sig.return_t_spec); @@ -2519,6 +2507,7 @@ impl ASTLowerer { return Err(self.return_incomplete_artifact(hir)); } }; + self.warn_implicit_union(&hir); self.warn_unused_expr(&hir.module, mode); self.warn_unused_vars(mode); self.check_doc_comments(&hir); diff --git a/crates/erg_compiler/ty/free.rs b/crates/erg_compiler/ty/free.rs index d49154a3..f17c95d5 100644 --- a/crates/erg_compiler/ty/free.rs +++ b/crates/erg_compiler/ty/free.rs @@ -462,6 +462,14 @@ impl FreeKind { } } } + + pub const fn is_named_unbound(&self) -> bool { + matches!(self, Self::NamedUnbound { .. }) + } + + pub const fn is_undoable_linked(&self) -> bool { + matches!(self, Self::UndoableLinked { .. }) + } } #[derive(Debug, Clone)] @@ -753,14 +761,15 @@ impl Free { } pub fn is_linked(&self) -> bool { - matches!( - &*self.borrow(), - FreeKind::Linked(_) | FreeKind::UndoableLinked { .. } - ) + self.borrow().linked().is_some() } pub fn is_undoable_linked(&self) -> bool { - matches!(&*self.borrow(), FreeKind::UndoableLinked { .. }) + self.borrow().is_undoable_linked() + } + + pub fn is_named_unbound(&self) -> bool { + self.borrow().is_named_unbound() } pub fn unsafe_crack(&self) -> &T { diff --git a/crates/erg_compiler/ty/mod.rs b/crates/erg_compiler/ty/mod.rs index 7e182fe0..d30e8c39 100644 --- a/crates/erg_compiler/ty/mod.rs +++ b/crates/erg_compiler/ty/mod.rs @@ -2081,6 +2081,22 @@ impl Type { matches!(self, Self::FreeVar(fv) if fv.is_unbound() || fv.crack().is_unbound_var()) } + pub fn is_named_unbound_var(&self) -> bool { + matches!(self, Self::FreeVar(fv) if fv.is_named_unbound() || (fv.is_linked() && fv.crack().is_named_unbound_var())) + } + + pub fn is_totally_unbound(&self) -> bool { + match self { + Self::FreeVar(fv) if fv.is_unbound() => true, + Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_totally_unbound(), + Self::Or(t1, t2) | Self::And(t1, t2) => { + t1.is_totally_unbound() && t2.is_totally_unbound() + } + Self::Not(t) => t.is_totally_unbound(), + _ => false, + } + } + /// See also: `is_monomorphized` pub fn is_monomorphic(&self) -> bool { matches!(self.typarams_len(), Some(0) | None) diff --git a/crates/erg_parser/desugar.rs b/crates/erg_parser/desugar.rs index 7e2bfe8c..d8bb9d00 100644 --- a/crates/erg_parser/desugar.rs +++ b/crates/erg_parser/desugar.rs @@ -13,10 +13,10 @@ use crate::ast::{ ClassAttr, ClassAttrs, ClassDef, ConstExpr, DataPack, Def, DefBody, DefId, Dict, Dummy, Expr, Identifier, KeyValue, KwArg, Lambda, LambdaSignature, Literal, Methods, MixedRecord, Module, NonDefaultParamSignature, NormalArray, NormalDict, NormalRecord, NormalSet, NormalTuple, - ParamPattern, ParamRecordAttr, Params, PatchDef, PosArg, ReDef, Record, RecordAttrOrIdent, - RecordAttrs, Set as astSet, SetWithLength, Signature, SubrSignature, Tuple, TupleTypeSpec, - TypeAppArgs, TypeAppArgsKind, TypeBoundSpecs, TypeSpec, TypeSpecWithOp, UnaryOp, VarName, - VarPattern, VarRecordAttr, VarSignature, VisModifierSpec, + ParamPattern, ParamRecordAttr, ParamTuplePattern, Params, PatchDef, PosArg, ReDef, Record, + RecordAttrOrIdent, RecordAttrs, Set as astSet, SetWithLength, Signature, SubrSignature, Tuple, + TupleTypeSpec, TypeAppArgs, TypeAppArgsKind, TypeBoundSpecs, TypeSpec, TypeSpecWithOp, UnaryOp, + VarName, VarPattern, VarRecordAttr, VarSignature, VisModifierSpec, }; use crate::token::{Token, TokenKind, COLON, DOT}; @@ -145,7 +145,7 @@ impl Desugarer { }, Expr::DataPack(pack) => { let class = desugar(*pack.class); - let args = enum_unwrap!(desugar(Expr::Record(pack.args)), Expr::Record); + let Expr::Record(args) = desugar(Expr::Record(pack.args)) else { unreachable!() }; Expr::DataPack(DataPack::new(class, pack.connector, args)) } Expr::Array(array) => match array { @@ -249,7 +249,7 @@ impl Desugarer { Expr::Def(Def::new(def.sig, body)) } Expr::ClassDef(class_def) => { - let def = enum_unwrap!(desugar(Expr::Def(class_def.def)), Expr::Def); + let Expr::Def(def) = desugar(Expr::Def(class_def.def)) else { unreachable!() }; let methods = class_def .methods_list .into_iter() @@ -258,7 +258,7 @@ impl Desugarer { Expr::ClassDef(ClassDef::new(def, methods)) } Expr::PatchDef(class_def) => { - let def = enum_unwrap!(desugar(Expr::Def(class_def.def)), Expr::Def); + let Expr::Def(def) = desugar(Expr::Def(class_def.def)) else { unreachable!() }; let methods = class_def .methods_list .into_iter() @@ -343,7 +343,7 @@ impl Desugarer { if let Some(Expr::Def(previous)) = new.last() { if previous.is_subr() && previous.sig.name_as_str() == def.sig.name_as_str() { - let previous = enum_unwrap!(new.pop().unwrap(), Expr::Def); + let Some(Expr::Def(previous)) = new.pop() else { unreachable!() }; let name = def.sig.ident().unwrap().clone(); let id = def.body.id; let op = def.body.op.clone(); @@ -354,17 +354,46 @@ impl Desugarer { } else { self.gen_match_call(previous, def) }; - let param_name = enum_unwrap!(&call.args.pos_args().iter().next().unwrap().expr, Expr::Accessor:(Accessor::Ident:(_))).inspect(); - // FIXME: multiple params - let param = VarName::new(Token::new( - TokenKind::Symbol, - param_name, - name.ln_begin().unwrap_or(1), - name.col_end().unwrap_or(0) + 1, // HACK: `(name) %x = ...`という形を想定 - )); - let param = - NonDefaultParamSignature::new(ParamPattern::VarName(param), None); - let params = Params::single(param); + let params = match &call.args.pos_args().iter().next().unwrap().expr { + Expr::Tuple(Tuple::Normal(tup)) => { + let mut params = vec![]; + for arg in tup.elems.pos_args().iter() { + match &arg.expr { + Expr::Accessor(Accessor::Ident(ident)) => { + let param_name = ident.inspect(); + let param = VarName::new(Token::new( + TokenKind::Symbol, + param_name, + name.ln_begin().unwrap_or(1), + name.col_end().unwrap_or(0) + 1, + )); + let param = NonDefaultParamSignature::new( + ParamPattern::VarName(param), + None, + ); + params.push(param); + } + _ => unreachable!(), + } + } + Params::new(params, None, vec![], None) + } + Expr::Accessor(Accessor::Ident(ident)) => { + let param_name = ident.inspect(); + let param = VarName::new(Token::new( + TokenKind::Symbol, + param_name, + name.ln_begin().unwrap_or(1), + name.col_end().unwrap_or(0) + 1, // HACK: `(name) %x = ...`という形を想定 + )); + let param = NonDefaultParamSignature::new( + ParamPattern::VarName(param), + None, + ); + Params::single(param) + } + _ => unreachable!(), + }; let sig = Signature::Subr(SubrSignature::new( set! {}, name, @@ -392,8 +421,8 @@ impl Desugarer { fn add_arg_to_match_call(&self, mut previous: Def, def: Def) -> (Call, Option) { let op = Token::from_str(TokenKind::FuncArrow, "->"); - let mut call = enum_unwrap!(previous.body.block.remove(0), Expr::Call); - let sig = enum_unwrap!(def.sig, Signature::Subr); + let Expr::Call(mut call) = previous.body.block.remove(0) else { unreachable!() }; + let Signature::Subr(sig) = def.sig else { unreachable!() }; let return_t_spec = sig.return_t_spec; let first_arg = sig.params.non_defaults.first().unwrap(); // 最後の定義の引数名を関数全体の引数名にする @@ -406,7 +435,14 @@ impl Desugarer { )); call.args.insert_pos(0, arg); } - let sig = LambdaSignature::new(sig.params, return_t_spec.clone(), sig.bounds); + // f(x, y, z) = ... => match x, ((x, y, z),) -> ... + let params = if sig.params.len() == 1 { + sig.params + } else { + let pat = ParamPattern::Tuple(ParamTuplePattern::new(sig.params)); + Params::single(NonDefaultParamSignature::new(pat, None)) + }; + let sig = LambdaSignature::new(params, return_t_spec.clone(), sig.bounds); let new_branch = Lambda::new(sig, op, def.body.block, def.body.id); call.args.push_pos(PosArg::new(Expr::Lambda(new_branch))); (call, return_t_spec) @@ -415,17 +451,39 @@ impl Desugarer { // TODO: procedural match fn gen_match_call(&self, previous: Def, def: Def) -> (Call, Option) { let op = Token::from_str(TokenKind::FuncArrow, "->"); - let sig = enum_unwrap!(previous.sig, Signature::Subr); + let Signature::Subr(prev_sig) = previous.sig else { unreachable!() }; + let params_len = prev_sig.params.len(); + let params = if params_len == 1 { + prev_sig.params + } else { + let pat = ParamPattern::Tuple(ParamTuplePattern::new(prev_sig.params)); + Params::single(NonDefaultParamSignature::new(pat, None)) + }; let match_symbol = Expr::static_local("match"); - let sig = LambdaSignature::new(sig.params, sig.return_t_spec, sig.bounds); + let sig = LambdaSignature::new(params, prev_sig.return_t_spec, prev_sig.bounds); let first_branch = Lambda::new(sig, op.clone(), previous.body.block, previous.body.id); - let sig = enum_unwrap!(def.sig, Signature::Subr); + let Signature::Subr(sig) = def.sig else { unreachable!() }; + let params = if sig.params.len() == 1 { + sig.params + } else { + let pat = ParamPattern::Tuple(ParamTuplePattern::new(sig.params)); + Params::single(NonDefaultParamSignature::new(pat, None)) + }; let return_t_spec = sig.return_t_spec; - let sig = LambdaSignature::new(sig.params, return_t_spec.clone(), sig.bounds); + let sig = LambdaSignature::new(params, return_t_spec.clone(), sig.bounds); let second_branch = Lambda::new(sig, op, def.body.block, def.body.id); + let first_arg = if params_len == 1 { + Expr::dummy_local(&fresh_varname()) + } else { + let args = (0..params_len).map(|_| PosArg::new(Expr::dummy_local(&fresh_varname()))); + Expr::Tuple(Tuple::Normal(NormalTuple::new(Args::pos_only( + args.collect(), + None, + )))) + }; let args = Args::pos_only( vec![ - PosArg::new(Expr::dummy_local("_")), // dummy argument, will be removed in line 56 + PosArg::new(first_arg), // dummy argument, will be removed in line 56 PosArg::new(Expr::Lambda(first_branch)), PosArg::new(Expr::Lambda(second_branch)), ], diff --git a/tests/should_ok/rec.er b/tests/should_ok/rec.er index 35097be6..03629e1d 100644 --- a/tests/should_ok/rec.er +++ b/tests/should_ok/rec.er @@ -9,5 +9,10 @@ stop_or_call n, f: (Nat -> Nat), g: (Nat -> Nat) = fact(n: Nat): Nat = stop_or_call n, fact, (r, ) -> r * n -print! fact -print! fact 5 +assert fact(5) == 120 + +iterate(_, 0, x) = x +iterate(f, n: Int, x) = iterate f, n-1, f x + +assert iterate((x -> x + 1), 5, 0) == 5 +assert iterate((x -> x + "a"), 5, "b") == "baaaaa" diff --git a/tests/test.rs b/tests/test.rs index 40ba3a5d..93043557 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -154,7 +154,7 @@ fn exec_raw_ident() -> Result<(), ()> { #[test] fn exec_rec() -> Result<(), ()> { - expect_success("tests/should_ok/rec.er", 1) + expect_success("tests/should_ok/rec.er", 0) } #[test]