diff --git a/compiler/erg_common/ty.rs b/compiler/erg_common/ty.rs index 64de79f8..d17a9a36 100644 --- a/compiler/erg_common/ty.rs +++ b/compiler/erg_common/ty.rs @@ -1448,12 +1448,20 @@ pub struct SubrType { impl fmt::Display for SubrType { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut default_params = String::new(); + for default_param in self.default_params.iter() { + default_params.push_str(&format!( + "{} |= {}", + default_param.name.as_ref().unwrap(), + default_param.ty + )); + } write!( f, "{}({}, {}) {} {}", self.kind.prefix(), fmt_vec(&self.non_default_params), - fmt_vec(&self.default_params), + default_params, self.kind.arrow(), self.return_t, ) diff --git a/compiler/erg_compiler/context.rs b/compiler/erg_compiler/context.rs index 07680921..c4a80853 100644 --- a/compiler/erg_compiler/context.rs +++ b/compiler/erg_compiler/context.rs @@ -1656,17 +1656,26 @@ impl Context { match t { // ?T(:> Nat, <: Int)[n] => Nat (self.level <= n) // ?T(:> Nat, <: Sub ?U(:> {1}))[n] => Nat + // ?T(:> Never, <: Nat)[n] => Nat Type::FreeVar(fv) if fv.constraint_is_sandwiched() => { let constraint = fv.crack_constraint(); let (sub, sup) = constraint.sub_sup_type().unwrap(); if self.rec_full_same_type_of(sub, sup) { self.unify(sub, sub, None, None)?; - let t = sub.clone(); + let t = if sub.rec_eq(&Never) { + sup.clone() + } else { + sub.clone() + }; drop(constraint); fv.link(&t); self.deref_tyvar(Type::FreeVar(fv)) } else if self.level == 0 || self.level <= fv.level().unwrap() { - let t = sub.clone(); + let t = if sub.rec_eq(&Never) { + sup.clone() + } else { + sub.clone() + }; drop(constraint); fv.link(&t); self.deref_tyvar(Type::FreeVar(fv)) @@ -1779,6 +1788,9 @@ impl Context { for arg in call.args.pos_args.iter_mut() { self.deref_expr_t(&mut arg.expr)?; } + for arg in call.args.kw_args.iter_mut() { + self.deref_expr_t(&mut arg.expr)?; + } Ok(()) } hir::Expr::Decl(decl) => { @@ -2841,6 +2853,7 @@ impl Context { } for pos_arg in pos_args.iter().skip(1) { let t = pos_arg.expr.ref_t(); + // Allow only anonymous functions to be passed as match arguments (for aesthetic reasons) if !matches!(&pos_arg.expr, hir::Expr::Lambda(_)) { return Err(TyCheckError::type_mismatch_error( line!() as usize, @@ -2852,37 +2865,38 @@ impl Context { )); } } - let expr_t = pos_args[0].expr.ref_t(); + let match_target_expr_t = pos_args[0].expr.ref_t(); // Never or T => T let mut union_pat_t = Type::Never; - for (i, a) in pos_args.iter().skip(1).enumerate() { - let lambda = erg_common::enum_unwrap!(&a.expr, hir::Expr::Lambda); + for (i, pos_arg) in pos_args.iter().skip(1).enumerate() { + let lambda = erg_common::enum_unwrap!(&pos_arg.expr, hir::Expr::Lambda); if !lambda.params.defaults.is_empty() { todo!() } + // TODO: If the first argument of the match is a tuple? if lambda.params.len() != 1 { return Err(TyCheckError::argument_error( line!() as usize, pos_args[i + 1].loc(), self.caused_by(), 1, - pos_args[i + 1].expr.ref_t().typarams_len(), + pos_args[i + 1].expr.signature_t().unwrap().typarams_len(), )); } let rhs = self.instantiate_param_sig_t(&lambda.params.non_defaults[0], None, Normal)?; union_pat_t = self.union(&union_pat_t, &rhs); } // NG: expr_t: Nat, union_pat_t: {1, 2} - // OK: expr_t: Int, union_pat_t: {1} | 'T - if expr_t.has_no_unbound_var() - && self.formal_supertype_of(expr_t, &union_pat_t, None, None) - && !self.formal_supertype_of(&union_pat_t, expr_t, None, None) + // OK: expr_t: Int, union_pat_t: {1} or 'T + if self + .sub_unify(match_target_expr_t, &union_pat_t, None, None) + .is_err() { return Err(TyCheckError::match_error( line!() as usize, pos_args[0].loc(), self.caused_by(), - expr_t, + match_target_expr_t, )); } let branch_ts = pos_args @@ -2894,12 +2908,8 @@ impl Context { for arg_t in branch_ts.iter().skip(1) { return_t = self.union(&return_t, arg_t.ty.return_t().unwrap()); } - let expr_t = if expr_t.has_unbound_var() { - union_pat_t - } else { - expr_t.clone() - }; - let param_ts = [vec![ParamTy::anonymous(expr_t)], branch_ts.to_vec()].concat(); + let param_ty = ParamTy::anonymous(match_target_expr_t.clone()); + let param_ts = [vec![param_ty], branch_ts.to_vec()].concat(); let t = Type::func(param_ts, vec![], return_t); Ok(t) } @@ -3163,14 +3173,18 @@ impl Context { return true; } } - for (patch_name, sub, sup) in self.glue_patch_and_types.iter() { + for (patch_name, sub_type, sup_trait) in self.glue_patch_and_types.iter() { let patch = self .rec_get_patch(patch_name) .unwrap_or_else(|| panic!("{patch_name} not found")); let bounds = patch.type_params_bounds(); let variance = patch.type_params_variance(); - if self.formal_supertype_of(sub, rhs, Some(&bounds), Some(&variance)) - && self.formal_supertype_of(sup, lhs, Some(&bounds), Some(&variance)) + // e.g. + // P = Patch X, Impl: Ord + // Rhs <: X => Rhs <: Ord + // Ord <: Lhs => Rhs <: Ord <: Lhs + if self.formal_supertype_of(sub_type, rhs, Some(&bounds), Some(&variance)) + && self.formal_subtype_of(sup_trait, lhs, Some(&bounds), Some(&variance)) { return true; }