diff --git a/compiler/erg_compiler/context/inquire.rs b/compiler/erg_compiler/context/inquire.rs index 05407373..2439b732 100644 --- a/compiler/erg_compiler/context/inquire.rs +++ b/compiler/erg_compiler/context/inquire.rs @@ -17,7 +17,7 @@ use ast::VarName; use erg_parser::ast::{self, Identifier}; use erg_parser::token::Token; -use crate::ty::constructors::{anon, free_var, func, mono, poly, proj, subr_t}; +use crate::ty::constructors::{anon, free_var, func, mono, poly, proc, proj, subr_t}; use crate::ty::free::Constraint; use crate::ty::typaram::TyParam; use crate::ty::value::{GenTypeObj, TypeObj, ValueObj}; @@ -234,6 +234,7 @@ impl Context { fn get_match_call_t( &self, + kind: SubrKind, pos_args: &[hir::PosArg], kw_args: &[hir::KwArg], ) -> TyCheckResult { @@ -310,7 +311,11 @@ impl Context { } let param_ty = ParamTy::anonymous(match_target_expr_t.clone()); let param_ts = [vec![param_ty], branch_ts.to_vec()].concat(); - let t = func(param_ts, None, vec![], return_t); + let t = if kind.is_func() { + func(param_ts, None, vec![], return_t) + } else { + proc(param_ts, None, vec![], return_t) + }; Ok(VarInfo { t, ..VarInfo::default() @@ -1162,7 +1167,10 @@ impl Context { #[allow(clippy::single_match)] match &local.inspect()[..] { "match" => { - return self.get_match_call_t(pos_args, kw_args); + return self.get_match_call_t(SubrKind::Func, pos_args, kw_args); + } + "match!" => { + return self.get_match_call_t(SubrKind::Proc, pos_args, kw_args); } /*"import" | "pyimport" | "py" => { return self.get_import_call_t(pos_args, kw_args); diff --git a/compiler/erg_compiler/lower.rs b/compiler/erg_compiler/lower.rs index 7b08c5a6..decb15ee 100644 --- a/compiler/erg_compiler/lower.rs +++ b/compiler/erg_compiler/lower.rs @@ -590,7 +590,9 @@ impl ASTLowerer { fn lower_ident(&self, ident: ast::Identifier) -> LowerResult { // `match` is an untypable special form // `match`は型付け不可能な特殊形式 - let (vi, __name__) = if ident.vis().is_private() && &ident.inspect()[..] == "match" { + let (vi, __name__) = if ident.vis().is_private() + && (&ident.inspect()[..] == "match" || &ident.inspect()[..] == "match!") + { (VarInfo::default(), None) } else { ( diff --git a/compiler/erg_compiler/ty/mod.rs b/compiler/erg_compiler/ty/mod.rs index 8256b6d5..835859a3 100644 --- a/compiler/erg_compiler/ty/mod.rs +++ b/compiler/erg_compiler/ty/mod.rs @@ -1068,6 +1068,13 @@ impl SubrKind { Self::Proc => Str::ever("=>"), } } + + pub fn is_func(&self) -> bool { + matches!(self, Self::Func) + } + pub fn is_proc(&self) -> bool { + matches!(self, Self::Proc) + } } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] diff --git a/compiler/erg_parser/ast.rs b/compiler/erg_parser/ast.rs index db7aa6fb..4c533e3a 100644 --- a/compiler/erg_parser/ast.rs +++ b/compiler/erg_parser/ast.rs @@ -988,7 +988,7 @@ impl Call { pub fn is_match(&self) -> bool { self.obj .get_name() - .map(|s| &s[..] == "match") + .map(|s| &s[..] == "match" || &s[..] == "match!") .unwrap_or(false) } diff --git a/compiler/erg_parser/desugar.rs b/compiler/erg_parser/desugar.rs index 47c730a3..2824dc6e 100644 --- a/compiler/erg_parser/desugar.rs +++ b/compiler/erg_parser/desugar.rs @@ -300,6 +300,7 @@ impl Desugarer { (call, return_t_spec) } + // TODO: procedural match fn gen_match_call(&self, previous: Def, def: Def) -> (Call, Option) { let op = Token::from_str(TokenKind::FuncArrow, "->"); let sig = enum_unwrap!(previous.sig, Signature::Subr);