diff --git a/crates/erg_common/triple.rs b/crates/erg_common/triple.rs index a3626d3d..0d25a185 100644 --- a/crates/erg_common/triple.rs +++ b/crates/erg_common/triple.rs @@ -77,6 +77,14 @@ impl Triple { } } + pub fn or_else_triple(self, f: impl FnOnce() -> Triple) -> Triple { + match self { + Triple::None => f(), + Triple::Ok(ok) => Triple::Ok(ok), + Triple::Err(err) => Triple::Err(err), + } + } + pub fn unwrap_or(self, default: T) -> T { match self { Triple::None => default, diff --git a/crates/erg_compiler/context/compare.rs b/crates/erg_compiler/context/compare.rs index b7e78f57..12e5b7a7 100644 --- a/crates/erg_compiler/context/compare.rs +++ b/crates/erg_compiler/context/compare.rs @@ -445,6 +445,23 @@ impl Context { && var_params_judge && default_check() // contravariant } + // {Int} <: Obj -> Int + (Subr(_) | Quantified(_), Refinement(refine)) + if rhs.singleton_value().is_some() && self.subtype_of(&refine.t, &ClassType) => + { + let Ok(typ) = self.convert_tp_into_type(rhs.singleton_value().unwrap().clone()) + else { + return false; + }; + let Some(ctx) = self.get_nominal_type_ctx(&typ) else { + return false; + }; + if let Some((_, __call__)) = ctx.get_class_attr("__call__") { + self.supertype_of(lhs, &__call__.t) + } else { + false + } + } // ?T(<: Int) :> ?U(:> Nat) // ?T(<: Int) :> ?U(:> Int) // ?T(<: Nat) !:> ?U(:> Int) (if the upper bound of LHS is smaller than the lower bound of RHS, LHS cannot not be a supertype) diff --git a/crates/erg_compiler/context/initialize/procs.rs b/crates/erg_compiler/context/initialize/procs.rs index ba32a51e..2dc9e80f 100644 --- a/crates/erg_compiler/context/initialize/procs.rs +++ b/crates/erg_compiler/context/initialize/procs.rs @@ -60,7 +60,7 @@ impl Context { ) .quantify(); let t_proc_ret = if PYTHON_MODE { Obj } else { NoneType }; - let t_for = nd_proc( + let t_for = proc( vec![ kw("iterable", poly("Iterable", vec![ty_tp(T.clone())])), kw( @@ -69,6 +69,8 @@ impl Context { ), ], None, + vec![kw("else!", nd_proc(vec![], None, t_proc_ret.clone()))], + None, NoneType, ) .quantify(); @@ -90,12 +92,14 @@ impl Context { // not Bool! type because `cond` may be the result of evaluation of a mutable object's method returns Bool. nd_proc(vec![], None, Bool) }; - let t_while = nd_proc( + let t_while = proc( vec![ kw("cond!", t_cond), - kw("proc!", nd_proc(vec![], None, t_proc_ret)), + kw("proc!", nd_proc(vec![], None, t_proc_ret.clone())), ], None, + vec![kw("else!", nd_proc(vec![], None, t_proc_ret.clone()))], + None, NoneType, ); let P = mono_q("P", subtypeof(mono("PathLike"))); diff --git a/crates/erg_compiler/context/inquire.rs b/crates/erg_compiler/context/inquire.rs index 628e3303..923b7efc 100644 --- a/crates/erg_compiler/context/inquire.rs +++ b/crates/erg_compiler/context/inquire.rs @@ -750,6 +750,27 @@ impl Context { Triple::None } + pub(crate) fn rec_get_param_or_decl_info(&self, name: &str) -> Option { + if let Some(vi) = self + .params + .iter() + .find(|(var_name, _)| var_name.as_ref().is_some_and(|n| n.inspect() == name)) + .map(|(_, vi)| vi) + .or_else(|| self.decls.get(name)) + { + return Some(vi.clone()); + } + for method_ctx in self.methods_list.iter() { + if let Some(vi) = method_ctx.rec_get_param_or_decl_info(name) { + return Some(vi); + } + } + if let Some(parent) = self.get_outer_scope().or_else(|| self.get_builtins()) { + return parent.rec_get_param_or_decl_info(name); + } + None + } + pub(crate) fn get_attr_info( &self, obj: &hir::Expr, diff --git a/crates/erg_compiler/context/register.rs b/crates/erg_compiler/context/register.rs index bc38f68e..c5d70a7d 100644 --- a/crates/erg_compiler/context/register.rs +++ b/crates/erg_compiler/context/register.rs @@ -130,7 +130,7 @@ impl Context { let mut errs = TyCheckErrors::empty(); let muty = Mutability::from(&sig.inspect().unwrap_or(UBAR)[..]); let ident = match &sig.pat { - ast::VarPattern::Ident(ident) => ident, + ast::VarPattern::Ident(ident) | ast::VarPattern::Phi(ident) => ident, ast::VarPattern::Discard(_) | ast::VarPattern::Glob(_) => { return Ok(()); } @@ -287,7 +287,7 @@ impl Context { None }; let ident = match &sig.pat { - ast::VarPattern::Ident(ident) => ident, + ast::VarPattern::Ident(ident) | ast::VarPattern::Phi(ident) => ident, ast::VarPattern::Discard(_) => { return Ok(VarInfo { t: body_t.clone(), diff --git a/crates/erg_compiler/context/unify.rs b/crates/erg_compiler/context/unify.rs index 9ac12c54..c91e16b8 100644 --- a/crates/erg_compiler/context/unify.rs +++ b/crates/erg_compiler/context/unify.rs @@ -1874,6 +1874,8 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { } self.sub_unify_pred(&sub.pred, &supe.pred)?; } + // {Int} <: Obj -> Int + (Refinement(_), Subr(_) | Quantified(_)) if maybe_sub.singleton_value().is_some() => {} // {I: Int | I >= 1} <: Nat == {I: Int | I >= 0} (Refinement(_), sup) => { let sup = sup.clone().into_refinement(); diff --git a/crates/erg_compiler/lower.rs b/crates/erg_compiler/lower.rs index 7f62796d..d33c6444 100644 --- a/crates/erg_compiler/lower.rs +++ b/crates/erg_compiler/lower.rs @@ -2253,6 +2253,11 @@ impl GenericASTLowerer { errors.extend(errs); } let outer = self.module.context.outer.as_ref().unwrap(); + let existing_vi = sig + .ident() + .and_then(|ident| outer.get_current_scope_var(&ident.name)) + .cloned(); + let existing_t = existing_vi.as_ref().map(|vi| vi.t.clone()); let expect_body_t = sig .t_spec .as_ref() @@ -2269,14 +2274,14 @@ impl GenericASTLowerer { }) .or_else(|| { sig.ident() - .and_then(|ident| outer.get_current_scope_var(&ident.name)) + .and_then(|ident| outer.rec_get_param_or_decl_info(ident.inspect())) .map(|vi| vi.t.clone()) }); match self.lower_block(body.block, expect_body.or(expect_body_t.as_ref())) { Ok(block) => { let found_body_t = block.ref_t(); let ident = match &sig.pat { - ast::VarPattern::Ident(ident) => ident.clone(), + ast::VarPattern::Ident(ident) | ast::VarPattern::Phi(ident) => ident.clone(), ast::VarPattern::Discard(token) => { ast::Identifier::private_from_token(token.clone()) } @@ -2291,6 +2296,7 @@ impl GenericASTLowerer { .map_err(|errs| (None, errors.concat(errs))); } }; + let mut no_reassign = false; if let Some(expect_body_t) = expect_body_t { // TODO: expect_body_t is smaller for constants // TODO: 定数の場合、expect_body_tのほうが小さくなってしまう @@ -2302,20 +2308,35 @@ impl GenericASTLowerer { found_body_t, ) { errors.push(e); + no_reassign = true; } } } - let vi = match self.module.context.outer.as_mut().unwrap().assign_var_sig( - &sig, - found_body_t, - body.id, - block.last(), - None, - ) { - Ok(vi) => vi, - Err(errs) => { - errors.extend(errs); - VarInfo::ILLEGAL + let found_body_t = if sig.is_phi() { + self.module + .context + .union(existing_t.as_ref().unwrap_or(&Type::Never), found_body_t) + } else { + found_body_t.clone() + }; + let vi = if no_reassign { + VarInfo { + t: found_body_t, + ..existing_vi.unwrap_or_default() + } + } else { + match self.module.context.outer.as_mut().unwrap().assign_var_sig( + &sig, + &found_body_t, + body.id, + block.last(), + None, + ) { + Ok(vi) => vi, + Err(errs) => { + errors.extend(errs); + VarInfo::ILLEGAL + } } }; let ident = hir::Identifier::new(ident, None, vi); @@ -2351,7 +2372,7 @@ impl GenericASTLowerer { errors.extend(errs); let found_body_t = block.ref_t(); let ident = match &sig.pat { - ast::VarPattern::Ident(ident) => ident.clone(), + ast::VarPattern::Ident(ident) | ast::VarPattern::Phi(ident) => ident.clone(), ast::VarPattern::Discard(token) => { ast::Identifier::private_from_token(token.clone()) } @@ -2366,9 +2387,16 @@ impl GenericASTLowerer { .map_err(|errs| (None, errors.concat(errs))); } }; + let found_body_t = if sig.is_phi() { + self.module + .context + .union(existing_t.as_ref().unwrap_or(&Type::Never), found_body_t) + } else { + found_body_t.clone() + }; if let Err(errs) = self.module.context.outer.as_mut().unwrap().assign_var_sig( &sig, - found_body_t, + &found_body_t, ast::DefId(0), None, None, diff --git a/crates/erg_parser/ast.rs b/crates/erg_parser/ast.rs index bad87b48..55be385a 100644 --- a/crates/erg_parser/ast.rs +++ b/crates/erg_parser/ast.rs @@ -4723,6 +4723,10 @@ pub enum VarPattern { Discard(Token), Glob(Token), Ident(Identifier), + /// Used when a different value is assigned in a branch other than `Ident`. + /// (e.g. the else variable when a variable is defined with Python if-else) + /// Not used in Erg mode at this time + Phi(Identifier), /// e.g. `[x, y, z]` of `[x, y, z] = [1, 2, 3]` List(VarListPattern), /// e.g. `(x, y, z)` of `(x, y, z) = (1, 2, 3)` @@ -4739,6 +4743,7 @@ impl NestedDisplay for VarPattern { Self::Discard(_) => write!(f, "_"), Self::Glob(_) => write!(f, "*"), Self::Ident(ident) => write!(f, "{ident}"), + Self::Phi(ident) => write!(f, "(phi){ident}"), Self::List(l) => write!(f, "{l}"), Self::Tuple(t) => write!(f, "{t}"), Self::Record(r) => write!(f, "{r}"), @@ -4748,9 +4753,9 @@ impl NestedDisplay for VarPattern { } impl_display_from_nested!(VarPattern); -impl_locational_for_enum!(VarPattern; Discard, Glob, Ident, List, Tuple, Record, DataPack); -impl_into_py_for_enum!(VarPattern; Discard, Glob, Ident, List, Tuple, Record, DataPack); -impl_from_py_for_enum!(VarPattern; Discard(Token), Glob(Token), Ident(Identifier), List(VarListPattern), Tuple(VarTuplePattern), Record(VarRecordPattern), DataPack(VarDataPackPattern)); +impl_locational_for_enum!(VarPattern; Discard, Glob, Ident, Phi, List, Tuple, Record, DataPack); +impl_into_py_for_enum!(VarPattern; Discard, Glob, Ident, Phi, List, Tuple, Record, DataPack); +impl_from_py_for_enum!(VarPattern; Discard(Token), Glob(Token), Ident(Identifier), Phi(Identifier), List(VarListPattern), Tuple(VarTuplePattern), Record(VarRecordPattern), DataPack(VarDataPackPattern)); impl VarPattern { pub const fn inspect(&self) -> Option<&Str> { @@ -4900,10 +4905,14 @@ impl VarSignature { pub fn ident(&self) -> Option<&Identifier> { match &self.pat { - VarPattern::Ident(ident) => Some(ident), + VarPattern::Ident(ident) | VarPattern::Phi(ident) => Some(ident), _ => None, } } + + pub fn is_phi(&self) -> bool { + matches!(self.pat, VarPattern::Phi(_)) + } } #[pyclass] diff --git a/crates/erg_parser/desugar.rs b/crates/erg_parser/desugar.rs index 9a6fc170..79fd79ac 100644 --- a/crates/erg_parser/desugar.rs +++ b/crates/erg_parser/desugar.rs @@ -763,7 +763,10 @@ impl Desugarer { self.desugar_nested_var_pattern(new, rhs, &buf_name, BufIndex::Record(lhs)); } } - VarPattern::Ident(_) | VarPattern::Discard(_) | VarPattern::Glob(_) => { + VarPattern::Ident(_) + | VarPattern::Phi(_) + | VarPattern::Discard(_) + | VarPattern::Glob(_) => { if let VarPattern::Ident(ident) = v.pat { v.pat = VarPattern::Ident(Self::desugar_ident(ident)); } @@ -966,7 +969,10 @@ impl Desugarer { ); } } - VarPattern::Ident(_) | VarPattern::Discard(_) | VarPattern::Glob(_) => { + VarPattern::Ident(_) + | VarPattern::Phi(_) + | VarPattern::Discard(_) + | VarPattern::Glob(_) => { let def = Def::new(Signature::Var(sig.clone()), body); new_module.push(Expr::Def(def)); }