From 4ff0b64fc37ca3ac7fc92ac3abdac203973078bf Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Thu, 23 Mar 2023 21:17:37 +0900 Subject: [PATCH] fix: type-instantiating bugs --- crates/erg_common/set.rs | 5 + crates/erg_compiler/context/eval.rs | 2 +- crates/erg_compiler/context/generalize.rs | 249 ++++++++---------- crates/erg_compiler/context/inquire.rs | 6 +- crates/erg_compiler/context/instantiate.rs | 50 +++- .../erg_compiler/context/instantiate_spec.rs | 206 +++++++++++---- crates/erg_compiler/context/unify.rs | 8 +- crates/erg_compiler/ty/mod.rs | 50 +++- 8 files changed, 357 insertions(+), 219 deletions(-) diff --git a/crates/erg_common/set.rs b/crates/erg_common/set.rs index 2abab41b..a55c235e 100644 --- a/crates/erg_common/set.rs +++ b/crates/erg_common/set.rs @@ -166,6 +166,11 @@ impl Set { self.elems.extend(iter); } + pub fn extended>(mut self, iter: I) -> Self { + self.elems.extend(iter); + self + } + #[inline] pub fn is_superset(&self, other: &Set) -> bool { self.elems.is_superset(&other.elems) diff --git a/crates/erg_compiler/context/eval.rs b/crates/erg_compiler/context/eval.rs index 37116300..05839905 100644 --- a/crates/erg_compiler/context/eval.rs +++ b/crates/erg_compiler/context/eval.rs @@ -1393,7 +1393,7 @@ impl Context { tp.clone() } else { let tp = TyParam::FreeVar(new_fv); - tv_cache.push_or_init_typaram(&name, &tp); + tv_cache.push_or_init_typaram(&name, &tp, self); tp } } diff --git a/crates/erg_compiler/context/generalize.rs b/crates/erg_compiler/context/generalize.rs index ab0037ca..5447a6ef 100644 --- a/crates/erg_compiler/context/generalize.rs +++ b/crates/erg_compiler/context/generalize.rs @@ -20,72 +20,75 @@ use crate::{feature_error, hir}; use Type::*; use Variance::*; -impl Context { - pub const TOP_LEVEL: usize = 1; +pub struct Generalizer { + level: usize, + variance: Variance, + qnames: Set, + structural_inner: bool, +} - fn generalize_tp( - &self, - free: TyParam, - variance: Variance, - qnames: &Set, - uninit: bool, - ) 
-> TyParam { +impl Generalizer { + pub fn new(level: usize) -> Self { + Self { + level, + variance: Covariant, + qnames: set! {}, + structural_inner: false, + } + } + + fn generalize_tp(&mut self, free: TyParam, uninit: bool) -> TyParam { match free { - TyParam::Type(t) => TyParam::t(self.generalize_t_inner(*t, variance, qnames, uninit)), + TyParam::Type(t) => TyParam::t(self.generalize_t(*t, uninit)), TyParam::FreeVar(fv) if fv.is_generalized() => TyParam::FreeVar(fv), TyParam::FreeVar(fv) if fv.is_linked() => { - self.generalize_tp(fv.crack().clone(), variance, qnames, uninit) + self.generalize_tp(fv.crack().clone(), uninit) } // TODO: Polymorphic generalization TyParam::FreeVar(fv) if fv.level() > Some(self.level) => { - let constr = self.generalize_constraint(&fv, qnames, variance); + let constr = self.generalize_constraint(&fv); fv.update_constraint(constr, true); fv.generalize(); TyParam::FreeVar(fv) } TyParam::Array(tps) => TyParam::Array( tps.into_iter() - .map(|tp| self.generalize_tp(tp, variance, qnames, uninit)) + .map(|tp| self.generalize_tp(tp, uninit)) .collect(), ), TyParam::Tuple(tps) => TyParam::Tuple( tps.into_iter() - .map(|tp| self.generalize_tp(tp, variance, qnames, uninit)) + .map(|tp| self.generalize_tp(tp, uninit)) .collect(), ), TyParam::Dict(tps) => TyParam::Dict( tps.into_iter() - .map(|(k, v)| { - ( - self.generalize_tp(k, variance, qnames, uninit), - self.generalize_tp(v, variance, qnames, uninit), - ) - }) + .map(|(k, v)| (self.generalize_tp(k, uninit), self.generalize_tp(v, uninit))) .collect(), ), TyParam::Record(rec) => TyParam::Record( rec.into_iter() - .map(|(field, tp)| (field, self.generalize_tp(tp, variance, qnames, uninit))) + .map(|(field, tp)| (field, self.generalize_tp(tp, uninit))) .collect(), ), TyParam::Lambda(lambda) => { let nd_params = lambda .nd_params .into_iter() - .map(|pt| pt.map_type(|t| self.generalize_t_inner(t, variance, qnames, uninit))) + .map(|pt| pt.map_type(|t| self.generalize_t(t, uninit))) 
.collect::>(); - let var_params = lambda.var_params.map(|pt| { - pt.map_type(|t| self.generalize_t_inner(t, variance, qnames, uninit)) - }); + let var_params = lambda + .var_params + .map(|pt| pt.map_type(|t| self.generalize_t(t, uninit))); let d_params = lambda .d_params .into_iter() - .map(|pt| pt.map_type(|t| self.generalize_t_inner(t, variance, qnames, uninit))) + .map(|pt| pt.map_type(|t| self.generalize_t(t, uninit))) .collect::>(); let body = lambda .body .into_iter() - .map(|tp| self.generalize_tp(tp, variance, qnames, uninit)) + .map(|tp| self.generalize_tp(tp, uninit)) .collect(); TyParam::Lambda(TyParamLambda::new( lambda.const_, @@ -97,26 +100,24 @@ impl Context { } TyParam::FreeVar(_) => free, TyParam::Proj { obj, attr } => { - let obj = self.generalize_tp(*obj, variance, qnames, uninit); + let obj = self.generalize_tp(*obj, uninit); TyParam::proj(obj, attr) } - TyParam::Erased(t) => { - TyParam::erased(self.generalize_t_inner(*t, variance, qnames, uninit)) - } + TyParam::Erased(t) => TyParam::erased(self.generalize_t(*t, uninit)), TyParam::App { name, args } => { let args = args .into_iter() - .map(|tp| self.generalize_tp(tp, variance, qnames, uninit)) + .map(|tp| self.generalize_tp(tp, uninit)) .collect(); TyParam::App { name, args } } TyParam::BinOp { op, lhs, rhs } => { - let lhs = self.generalize_tp(*lhs, variance, qnames, uninit); - let rhs = self.generalize_tp(*rhs, variance, qnames, uninit); + let lhs = self.generalize_tp(*lhs, uninit); + let rhs = self.generalize_tp(*rhs, uninit); TyParam::bin(op, lhs, rhs) } TyParam::UnaryOp { op, val } => { - let val = self.generalize_tp(*val, variance, qnames, uninit); + let val = self.generalize_tp(*val, uninit); TyParam::unary(op, val) } other if other.has_no_unbound_var() => other, @@ -124,17 +125,6 @@ impl Context { } } - /// Quantification occurs only once in function types. - /// Therefore, this method is called only once at the top level, and `generalize_t_inner` is called inside. 
- pub(crate) fn generalize_t(&self, free_type: Type) -> Type { - let maybe_unbound_t = self.generalize_t_inner(free_type, Covariant, &set! {}, false); - if maybe_unbound_t.is_subr() && maybe_unbound_t.has_qvar() { - maybe_unbound_t.quantify() - } else { - maybe_unbound_t - } - } - /// see doc/LANG/compiler/inference.md#一般化 for details /// ```python /// generalize_t(?T) == 'T: Type @@ -142,17 +132,9 @@ impl Context { /// generalize_t(?T(<: Add(?T(<: Eq(?T(<: ...)))) -> ?T) == |'T <: Add('T)| 'T -> 'T /// generalize_t(?T(<: TraitX) -> Int) == TraitX -> Int // 戻り値に現れないなら量化しない /// ``` - fn generalize_t_inner( - &self, - free_type: Type, - variance: Variance, - qnames: &Set, - uninit: bool, - ) -> Type { + fn generalize_t(&mut self, free_type: Type, uninit: bool) -> Type { match free_type { - FreeVar(fv) if fv.is_linked() => { - self.generalize_t_inner(fv.crack().clone(), variance, qnames, uninit) - } + FreeVar(fv) if fv.is_linked() => self.generalize_t(fv.crack().clone(), uninit), FreeVar(fv) if fv.is_generalized() => Type::FreeVar(fv), // TODO: Polymorphic generalization FreeVar(fv) if fv.level().unwrap() > self.level => { @@ -163,63 +145,49 @@ impl Context { if let Some((sub, sup)) = fv.get_subsup() { // |Int <: T <: Int| T -> T ==> Int -> Int if sub == sup { - let t = self.generalize_t_inner(sub, variance, qnames, uninit); + let t = self.generalize_t(sub, uninit); fv.forced_link(&t); FreeVar(fv) } else if sup != Obj - && !qnames.contains(&fv.unbound_name().unwrap()) - && variance == Contravariant + && !self.qnames.contains(&fv.unbound_name().unwrap()) + && self.variance == Contravariant { // |T <: Bool| T -> Int ==> Bool -> Int - self.generalize_t_inner(sup, variance, qnames, uninit) + self.generalize_t(sup, uninit) } else if sub != Never - && !qnames.contains(&fv.unbound_name().unwrap()) - && variance == Covariant + && !self.qnames.contains(&fv.unbound_name().unwrap()) + && self.variance == Covariant { // |T :> Int| X -> T ==> X -> Int - 
self.generalize_t_inner(sub, variance, qnames, uninit) + self.generalize_t(sub, uninit) } else { - fv.update_constraint( - self.generalize_constraint(&fv, qnames, variance), - true, - ); + fv.update_constraint(self.generalize_constraint(&fv), true); fv.generalize(); Type::FreeVar(fv) } } else { // ?S(: Str) => 'S - fv.update_constraint(self.generalize_constraint(&fv, qnames, variance), true); + fv.update_constraint(self.generalize_constraint(&fv), true); fv.generalize(); Type::FreeVar(fv) } } Subr(mut subr) => { + self.variance = Contravariant; let qnames = subr.essential_qnames(); + self.qnames.extend(qnames.clone()); subr.non_default_params.iter_mut().for_each(|nd_param| { - *nd_param.typ_mut() = self.generalize_t_inner( - mem::take(nd_param.typ_mut()), - Contravariant, - &qnames, - uninit, - ); + *nd_param.typ_mut() = self.generalize_t(mem::take(nd_param.typ_mut()), uninit); }); if let Some(var_args) = &mut subr.var_params { - *var_args.typ_mut() = self.generalize_t_inner( - mem::take(var_args.typ_mut()), - Contravariant, - &qnames, - uninit, - ); + *var_args.typ_mut() = self.generalize_t(mem::take(var_args.typ_mut()), uninit); } subr.default_params.iter_mut().for_each(|d_param| { - *d_param.typ_mut() = self.generalize_t_inner( - mem::take(d_param.typ_mut()), - Contravariant, - &qnames, - uninit, - ); + *d_param.typ_mut() = self.generalize_t(mem::take(d_param.typ_mut()), uninit); }); - let return_t = self.generalize_t_inner(*subr.return_t, Covariant, &qnames, uninit); + self.variance = Covariant; + let return_t = self.generalize_t(*subr.return_t, uninit); + self.qnames = self.qnames.difference(&qnames); subr_t( subr.kind, subr.non_default_params, @@ -231,34 +199,30 @@ impl Context { Record(rec) => { let fields = rec .into_iter() - .map(|(name, t)| (name, self.generalize_t_inner(t, variance, qnames, uninit))) + .map(|(name, t)| (name, self.generalize_t(t, uninit))) .collect(); Type::Record(fields) } Callable { .. 
} => todo!(), - Ref(t) => ref_(self.generalize_t_inner(*t, variance, qnames, uninit)), + Ref(t) => ref_(self.generalize_t(*t, uninit)), RefMut { before, after } => { - let after = - after.map(|aft| self.generalize_t_inner(*aft, variance, qnames, uninit)); - ref_mut( - self.generalize_t_inner(*before, variance, qnames, uninit), - after, - ) + let after = after.map(|aft| self.generalize_t(*aft, uninit)); + ref_mut(self.generalize_t(*before, uninit), after) } Refinement(refine) => { - let t = self.generalize_t_inner(*refine.t, variance, qnames, uninit); - let pred = self.generalize_pred(*refine.pred, variance, qnames, uninit); + let t = self.generalize_t(*refine.t, uninit); + let pred = self.generalize_pred(*refine.pred, uninit); refinement(refine.var, t, pred) } Poly { name, mut params } => { let params = params .iter_mut() - .map(|p| self.generalize_tp(mem::take(p), variance, qnames, uninit)) + .map(|p| self.generalize_tp(mem::take(p), uninit)) .collect::>(); poly(name, params) } Proj { lhs, rhs } => { - let lhs = self.generalize_t_inner(*lhs, variance, qnames, uninit); + let lhs = self.generalize_t(*lhs, uninit); proj(lhs, rhs) } ProjCall { @@ -266,98 +230,111 @@ impl Context { attr_name, mut args, } => { - let lhs = self.generalize_tp(*lhs, variance, qnames, uninit); + let lhs = self.generalize_tp(*lhs, uninit); for arg in args.iter_mut() { - *arg = self.generalize_tp(mem::take(arg), variance, qnames, uninit); + *arg = self.generalize_tp(mem::take(arg), uninit); } proj_call(lhs, attr_name, args) } And(l, r) => { - let l = self.generalize_t_inner(*l, variance, qnames, uninit); - let r = self.generalize_t_inner(*r, variance, qnames, uninit); + let l = self.generalize_t(*l, uninit); + let r = self.generalize_t(*r, uninit); // not `self.intersection` because types are generalized and(l, r) } Or(l, r) => { - let l = self.generalize_t_inner(*l, variance, qnames, uninit); - let r = self.generalize_t_inner(*r, variance, qnames, uninit); + let l = self.generalize_t(*l, 
uninit); + let r = self.generalize_t(*r, uninit); // not `self.union` because types are generalized or(l, r) } - Not(l) => not(self.generalize_t_inner(*l, variance, qnames, uninit)), - Structural(t) => self - .generalize_t_inner(*t, variance, qnames, uninit) - .structuralize(), + Not(l) => not(self.generalize_t(*l, uninit)), + Structural(ty) => { + if self.structural_inner { + ty.structuralize() + } else { + if ty.is_recursive() { + self.structural_inner = true; + } + let res = self.generalize_t(*ty, uninit).structuralize(); + self.structural_inner = false; + res + } + } // REVIEW: その他何でもそのまま通していいのか? other => other, } } - fn generalize_constraint( - &self, - fv: &Free, - qnames: &Set, - variance: Variance, - ) -> Constraint { + fn generalize_constraint(&mut self, fv: &Free) -> Constraint { if let Some((sub, sup)) = fv.get_subsup() { - let sub = self.generalize_t_inner(sub, variance, qnames, true); - let sup = self.generalize_t_inner(sup, variance, qnames, true); + let sub = self.generalize_t(sub, true); + let sup = self.generalize_t(sup, true); Constraint::new_sandwiched(sub, sup) } else if let Some(ty) = fv.get_type() { - let t = self.generalize_t_inner(ty, variance, qnames, true); + let t = self.generalize_t(ty, true); Constraint::new_type_of(t) } else { unreachable!() } } - fn generalize_pred( - &self, - pred: Predicate, - variance: Variance, - qnames: &Set, - uninit: bool, - ) -> Predicate { + fn generalize_pred(&mut self, pred: Predicate, uninit: bool) -> Predicate { match pred { Predicate::Const(_) => pred, Predicate::Value(ValueObj::Type(mut typ)) => { - *typ.typ_mut() = - self.generalize_t_inner(mem::take(typ.typ_mut()), variance, qnames, uninit); + *typ.typ_mut() = self.generalize_t(mem::take(typ.typ_mut()), uninit); Predicate::Value(ValueObj::Type(typ)) } Predicate::Value(_) => pred, Predicate::Equal { lhs, rhs } => { - let rhs = self.generalize_tp(rhs, variance, qnames, uninit); + let rhs = self.generalize_tp(rhs, uninit); Predicate::eq(lhs, rhs) } 
Predicate::GreaterEqual { lhs, rhs } => { - let rhs = self.generalize_tp(rhs, variance, qnames, uninit); + let rhs = self.generalize_tp(rhs, uninit); Predicate::ge(lhs, rhs) } Predicate::LessEqual { lhs, rhs } => { - let rhs = self.generalize_tp(rhs, variance, qnames, uninit); + let rhs = self.generalize_tp(rhs, uninit); Predicate::le(lhs, rhs) } Predicate::NotEqual { lhs, rhs } => { - let rhs = self.generalize_tp(rhs, variance, qnames, uninit); + let rhs = self.generalize_tp(rhs, uninit); Predicate::ne(lhs, rhs) } Predicate::And(lhs, rhs) => { - let lhs = self.generalize_pred(*lhs, variance, qnames, uninit); - let rhs = self.generalize_pred(*rhs, variance, qnames, uninit); + let lhs = self.generalize_pred(*lhs, uninit); + let rhs = self.generalize_pred(*rhs, uninit); Predicate::and(lhs, rhs) } Predicate::Or(lhs, rhs) => { - let lhs = self.generalize_pred(*lhs, variance, qnames, uninit); - let rhs = self.generalize_pred(*rhs, variance, qnames, uninit); + let lhs = self.generalize_pred(*lhs, uninit); + let rhs = self.generalize_pred(*rhs, uninit); Predicate::or(lhs, rhs) } Predicate::Not(pred) => { - let pred = self.generalize_pred(*pred, variance, qnames, uninit); + let pred = self.generalize_pred(*pred, uninit); !pred } } } +} + +impl Context { + pub const TOP_LEVEL: usize = 1; + + /// Quantification occurs only once in function types. + /// Therefore, this method is called only once at the top level, and `Generalizer::generalize_t` is called inside.
+ pub(crate) fn generalize_t(&self, free_type: Type) -> Type { + let mut generalizer = Generalizer::new(self.level); + let maybe_unbound_t = generalizer.generalize_t(free_type, false); + if maybe_unbound_t.is_subr() && maybe_unbound_t.has_qvar() { + maybe_unbound_t.quantify() + } else { + maybe_unbound_t + } + } pub(crate) fn deref_tp( &self, diff --git a/crates/erg_compiler/context/inquire.rs b/crates/erg_compiler/context/inquire.rs index d2fb7e8e..cdd6f488 100644 --- a/crates/erg_compiler/context/inquire.rs +++ b/crates/erg_compiler/context/inquire.rs @@ -2033,7 +2033,7 @@ impl Context { } } Type::FreeVar(fv) => { - let sup = fv.get_super().unwrap(); + let sup = fv.get_super()?; if let Some(res) = self.get_nominal_type_ctx(&sup) { return Some(res); } @@ -2727,9 +2727,9 @@ impl Context { } } - pub(crate) fn get_tp_from_name( + pub(crate) fn get_tp_from_tv_cache( &self, - name: &Str, + name: &str, tmp_tv_cache: &TyVarCache, ) -> Option { if let Some(tp) = tmp_tv_cache.get_typaram(name) { diff --git a/crates/erg_compiler/context/instantiate.rs b/crates/erg_compiler/context/instantiate.rs index 6e8d3028..212e283c 100644 --- a/crates/erg_compiler/context/instantiate.rs +++ b/crates/erg_compiler/context/instantiate.rs @@ -80,6 +80,15 @@ impl TyVarCache { } } + pub fn purge(&mut self, other: &Self) { + for name in other.tyvar_instances.keys() { + self.tyvar_instances.remove(name); + } + for name in other.typaram_instances.keys() { + self.typaram_instances.remove(name); + } + } + fn instantiate_constraint( &mut self, constr: Constraint, @@ -125,10 +134,10 @@ impl TyVarCache { pub(crate) fn push_or_init_tyvar(&mut self, name: &Str, tv: &Type, ctx: &Context) { if let Some(inst) = self.tyvar_instances.get(name) { - self.update_tv(inst, tv, ctx); + self.update_tyvar(inst, tv, ctx); } else if let Some(inst) = self.typaram_instances.get(name) { - if let TyParam::Type(inst) = inst { - self.update_tv(inst, tv, ctx); + if let Ok(inst) = <&Type>::try_from(inst) { + 
self.update_tyvar(inst, tv, ctx); } else if let TyParam::FreeVar(fv) = inst { fv.link(&TyParam::t(tv.clone())); } else { @@ -139,7 +148,7 @@ impl TyVarCache { } } - fn update_tv(&self, inst: &Type, tv: &Type, ctx: &Context) { + fn update_tyvar(&self, inst: &Type, tv: &Type, ctx: &Context) { // T <: Eq(T) // T is uninitialized // T.link(T); @@ -165,19 +174,34 @@ impl TyVarCache { } } - pub(crate) fn push_or_init_typaram(&mut self, name: &Str, tp: &TyParam) { + pub(crate) fn push_or_init_typaram(&mut self, name: &Str, tp: &TyParam, ctx: &Context) { // FIXME: - if let Some(_tp) = self.typaram_instances.get(name) { - panic!("{_tp} {tp}"); - // return; - } else if let Some(_t) = self.tyvar_instances.get(name) { - panic!("{_t} {tp}"); - // return; + if let Some(inst) = self.typaram_instances.get(name) { + self.update_typaram(inst, tp, ctx); + } else if let Some(inst) = self.tyvar_instances.get(name) { + if let Ok(tv) = <&Type>::try_from(tp) { + self.update_tyvar(inst, tv, ctx); + } else { + unreachable!() + } } else { self.typaram_instances.insert(name.clone(), tp.clone()); } } + fn update_typaram(&self, inst: &TyParam, tp: &TyParam, ctx: &Context) { + let inst = enum_unwrap!(inst, TyParam::FreeVar); + if inst.constraint_is_uninited() { + inst.link(tp); + } else { + let old_type = inst.get_type().unwrap(); + let tv = enum_unwrap!(tp, TyParam::FreeVar); + let new_type = tv.get_type().unwrap(); + let new_constraint = Constraint::new_type_of(ctx.intersection(&old_type, &new_type)); + inst.update_constraint(new_constraint, true); + } + } + pub(crate) fn appeared(&self, name: &Str) -> bool { self.already_appeared.contains(name) } @@ -240,7 +264,7 @@ impl Context { if tmp_tv_cache.appeared(&name) { let tp = TyParam::named_free_var(name.clone(), self.level, Constraint::Uninited); - tmp_tv_cache.push_or_init_typaram(&name, &tp); + tmp_tv_cache.push_or_init_typaram(&name, &tp, self); return Ok(tp); } if let Some(tv_cache) = &self.tv_cache { @@ -253,7 +277,7 @@ impl Context { 
tmp_tv_cache.push_appeared(name.clone()); let constr = tmp_tv_cache.instantiate_constraint(constr, self, loc)?; let tp = TyParam::named_free_var(name.clone(), self.level, constr); - tmp_tv_cache.push_or_init_typaram(&name, &tp); + tmp_tv_cache.push_or_init_typaram(&name, &tp, self); Ok(tp) } } diff --git a/crates/erg_compiler/context/instantiate_spec.rs b/crates/erg_compiler/context/instantiate_spec.rs index 5c2805b0..7aa93920 100644 --- a/crates/erg_compiler/context/instantiate_spec.rs +++ b/crates/erg_compiler/context/instantiate_spec.rs @@ -4,7 +4,7 @@ use std::option::Option; // conflicting to Type::Option use erg_common::log; use erg_common::traits::{Locational, Stream}; use erg_common::Str; -use erg_common::{assume_unreachable, dict, enum_unwrap, set, try_map_mut}; +use erg_common::{assume_unreachable, dict, set, try_map_mut}; use ast::{ NonDefaultParamSignature, ParamTySpec, PreDeclTypeSpec, SimpleTypeSpec, TypeBoundSpec, @@ -147,7 +147,7 @@ impl Context { }; if constr.get_sub_sup().is_none() { let tp = TyParam::named_free_var(lhs.inspect().clone(), self.level, constr); - tv_cache.push_or_init_typaram(lhs.inspect(), &tp); + tv_cache.push_or_init_typaram(lhs.inspect(), &tp, self); } else { let tv = named_free_var(lhs.inspect().clone(), self.level, constr); tv_cache.push_or_init_tyvar(lhs.inspect(), &tv, self); @@ -388,6 +388,9 @@ impl Context { Ok(spec_t) } + /// Given the type `T -> U`, if `T` is a known type, then this is a function type that takes `T` and returns `U`. + /// If the type `T` is not defined, then `T` is considered a constant parameter. + /// FIXME: The type bounds are processed regardless of the order in the specification, but in the current implementation, an undefined type may be considered a constant parameter. pub(crate) fn instantiate_param_ty( &self, sig: &NonDefaultParamSignature, @@ -403,7 +406,7 @@ impl Context { return Ok(ParamTy::Pos(v_enum(set!
{ value }))); } else if let Some(tp) = sig .name() - .and_then(|name| self.get_tp_from_name(name.inspect(), tmp_tv_cache)) + .and_then(|name| self.get_tp_from_tv_cache(name.inspect(), tmp_tv_cache)) { match tp { TyParam::Type(t) => return Ok(ParamTy::Pos(*t)), @@ -439,9 +442,12 @@ impl Context { } ast::PreDeclTypeSpec::Attr { namespace, t } => { if let Ok(receiver) = Parser::validate_const_expr(namespace.as_ref().clone()) { - if let Ok(receiver_t) = - self.instantiate_const_expr_as_type(&receiver, None, tmp_tv_cache) - { + if let Ok(receiver_t) = self.instantiate_const_expr_as_type( + &receiver, + None, + tmp_tv_cache, + not_found_is_qvar, + ) { let rhs = t.ident.inspect(); return Ok(proj(receiver_t, rhs)); } @@ -494,9 +500,19 @@ impl Context { // TODO: kw let mut args = simple.args.pos_args(); if let Some(first) = args.next() { - let t = self.instantiate_const_expr_as_type(&first.expr, None, tmp_tv_cache)?; + let t = self.instantiate_const_expr_as_type( + &first.expr, + None, + tmp_tv_cache, + not_found_is_qvar, + )?; let len = if let Some(len) = args.next() { - self.instantiate_const_expr(&len.expr, None, tmp_tv_cache)? + self.instantiate_const_expr( + &len.expr, + None, + tmp_tv_cache, + not_found_is_qvar, + )? 
} else { TyParam::erased(Nat) }; @@ -517,7 +533,12 @@ impl Context { vec![Str::from("T")], ))); }; - let t = self.instantiate_const_expr_as_type(&first.expr, None, tmp_tv_cache)?; + let t = self.instantiate_const_expr_as_type( + &first.expr, + None, + tmp_tv_cache, + not_found_is_qvar, + )?; Ok(ref_(t)) } "RefMut" => { @@ -533,7 +554,12 @@ impl Context { vec![Str::from("T")], ))); }; - let t = self.instantiate_const_expr_as_type(&first.expr, None, tmp_tv_cache)?; + let t = self.instantiate_const_expr_as_type( + &first.expr, + None, + tmp_tv_cache, + not_found_is_qvar, + )?; Ok(ref_mut(t, None)) } "Structural" => { @@ -548,7 +574,12 @@ impl Context { vec![Str::from("Type")], ))); }; - let t = self.instantiate_const_expr_as_type(&first.expr, None, tmp_tv_cache)?; + let t = self.instantiate_const_expr_as_type( + &first.expr, + None, + tmp_tv_cache, + not_found_is_qvar, + )?; Ok(t.structuralize()) } "Self" => self.rec_get_self_t().ok_or_else(|| { @@ -559,19 +590,8 @@ impl Context { )) }), other if simple.args.is_empty() => { - if let Some(t) = tmp_tv_cache.get_tyvar(other) { - return Ok(t.clone()); - } else if let Some(tp) = tmp_tv_cache.get_typaram(other) { - let t = enum_unwrap!(tp, TyParam::Type); - return Ok(t.as_ref().clone()); - } - if let Some(tv_cache) = &self.tv_cache { - if let Some(t) = tv_cache.get_tyvar(other) { - return Ok(t.clone()); - } else if let Some(tp) = tv_cache.get_typaram(other) { - let t = enum_unwrap!(tp, TyParam::Type); - return Ok(t.as_ref().clone()); - } + if let Some(TyParam::Type(t)) = self.get_tp_from_tv_cache(other, tmp_tv_cache) { + return Ok(*t); } if let Some(outer) = &self.outer { if let Ok(t) = outer.instantiate_simple_t( @@ -619,8 +639,12 @@ impl Context { // FIXME: kw args let mut new_params = vec![]; for (i, arg) in simple.args.pos_args().enumerate() { - let params = - self.instantiate_const_expr(&arg.expr, Some((ctx, i)), tmp_tv_cache); + let params = self.instantiate_const_expr( + &arg.expr, + Some((ctx, i)), + tmp_tv_cache, + 
not_found_is_qvar, + ); let params = params.or_else(|e| { if not_found_is_qvar { let name = arg.expr.to_string(); @@ -631,7 +655,7 @@ impl Context { self.level, Constraint::Uninited, ); - tmp_tv_cache.push_or_init_typaram(&name, &tp); + tmp_tv_cache.push_or_init_typaram(&name, &tp, self); Ok(tp) } else { Err(e) @@ -651,6 +675,7 @@ impl Context { erased_idx: Option<(&Context, usize)>, tmp_tv_cache: &mut TyVarCache, loc: &impl Locational, + not_found_is_qvar: bool, ) -> TyCheckResult { if &name[..] == "_" { let t = if let Some((ctx, i)) = erased_idx { @@ -660,12 +685,17 @@ impl Context { }; return Ok(TyParam::erased(t)); } - if let Some(tp) = self.get_tp_from_name(name, tmp_tv_cache) { + if let Some(tp) = self.get_tp_from_tv_cache(name, tmp_tv_cache) { return Ok(tp); } if let Some(value) = self.rec_get_const_obj(name) { return Ok(TyParam::Value(value.clone())); } + if not_found_is_qvar { + let tyvar = named_free_var(name.clone(), self.level, Constraint::Uninited); + tmp_tv_cache.push_or_init_tyvar(name, &tyvar, self); + return Ok(TyParam::t(tyvar)); + } Err(TyCheckErrors::from(TyCheckError::no_var_error( self.cfg.input.clone(), line!() as usize, @@ -681,23 +711,39 @@ impl Context { expr: &ast::ConstExpr, erased_idx: Option<(&Context, usize)>, tmp_tv_cache: &mut TyVarCache, + not_found_is_qvar: bool, ) -> TyCheckResult { match expr { ast::ConstExpr::Lit(lit) => Ok(TyParam::Value(self.eval_lit(lit)?)), // TODO: inc_ref ast::ConstExpr::Accessor(ast::ConstAccessor::Attr(attr)) => { - let obj = self.instantiate_const_expr(&attr.obj, erased_idx, tmp_tv_cache)?; + let obj = self.instantiate_const_expr( + &attr.obj, + erased_idx, + tmp_tv_cache, + not_found_is_qvar, + )?; Ok(obj.proj(attr.name.inspect())) } ast::ConstExpr::Accessor(ast::ConstAccessor::Local(local)) => { self.inc_ref_local(local, self); - self.instantiate_local(local.inspect(), erased_idx, tmp_tv_cache, local) + self.instantiate_local( + local.inspect(), + erased_idx, + tmp_tv_cache, + local, + 
not_found_is_qvar, + ) } ast::ConstExpr::Array(array) => { let mut tp_arr = vec![]; for (i, elem) in array.elems.pos_args().enumerate() { - let el = - self.instantiate_const_expr(&elem.expr, Some((self, i)), tmp_tv_cache)?; + let el = self.instantiate_const_expr( + &elem.expr, + Some((self, i)), + tmp_tv_cache, + not_found_is_qvar, + )?; tp_arr.push(el); } Ok(TyParam::Array(tp_arr)) @@ -705,8 +751,12 @@ impl Context { ast::ConstExpr::Set(set) => { let mut tp_set = set! {}; for (i, elem) in set.elems.pos_args().enumerate() { - let el = - self.instantiate_const_expr(&elem.expr, Some((self, i)), tmp_tv_cache)?; + let el = self.instantiate_const_expr( + &elem.expr, + Some((self, i)), + tmp_tv_cache, + not_found_is_qvar, + )?; tp_set.insert(el); } Ok(TyParam::Set(tp_set)) @@ -714,10 +764,18 @@ impl Context { ast::ConstExpr::Dict(dict) => { let mut tp_dict = dict! {}; for (i, elem) in dict.kvs.iter().enumerate() { - let key = - self.instantiate_const_expr(&elem.key, Some((self, i)), tmp_tv_cache)?; - let val = - self.instantiate_const_expr(&elem.value, Some((self, i)), tmp_tv_cache)?; + let key = self.instantiate_const_expr( + &elem.key, + Some((self, i)), + tmp_tv_cache, + not_found_is_qvar, + )?; + let val = self.instantiate_const_expr( + &elem.value, + Some((self, i)), + tmp_tv_cache, + not_found_is_qvar, + )?; tp_dict.insert(key, val); } Ok(TyParam::Dict(tp_dict)) @@ -725,8 +783,12 @@ impl Context { ast::ConstExpr::Tuple(tuple) => { let mut tp_tuple = vec![]; for (i, elem) in tuple.elems.pos_args().enumerate() { - let el = - self.instantiate_const_expr(&elem.expr, Some((self, i)), tmp_tv_cache)?; + let el = self.instantiate_const_expr( + &elem.expr, + Some((self, i)), + tmp_tv_cache, + not_found_is_qvar, + )?; tp_tuple.push(el); } Ok(TyParam::Tuple(tp_tuple)) @@ -739,20 +801,18 @@ impl Context { attr.body.block.get(0).unwrap(), None, tmp_tv_cache, + not_found_is_qvar, )?; tp_rec.insert(field, val); } Ok(TyParam::Record(tp_rec)) } ast::ConstExpr::Lambda(lambda) => { - 
let mut _tmp_tv_cache = + let _tmp_tv_cache = self.instantiate_ty_bounds(&lambda.sig.bounds, RegistrationMode::Normal)?; - let tmp_tv_cache = if tmp_tv_cache.is_empty() { - &mut _tmp_tv_cache - } else { - // TODO: prohibit double quantification - tmp_tv_cache - }; + // Since there are type variables and other variables that can be constrained within closures, + // they are `merge`d once and then `purge`d of type variables that are only used internally after instantiation. + tmp_tv_cache.merge(&_tmp_tv_cache); let mut nd_params = Vec::with_capacity(lambda.sig.params.non_defaults.len()); for sig in lambda.sig.params.non_defaults.iter() { let pt = self.instantiate_param_ty( @@ -790,9 +850,11 @@ impl Context { } let mut body = vec![]; for expr in lambda.body.iter() { - let param = self.instantiate_const_expr(expr, None, tmp_tv_cache)?; + let param = + self.instantiate_const_expr(expr, None, tmp_tv_cache, not_found_is_qvar)?; body.push(param); } + tmp_tv_cache.purge(&_tmp_tv_cache); Ok(TyParam::Lambda(TyParamLambda::new( lambda.clone(), nd_params, @@ -809,8 +871,18 @@ impl Context { &format!("instantiating const expression {bin}") ) }; - let lhs = self.instantiate_const_expr(&bin.lhs, erased_idx, tmp_tv_cache)?; - let rhs = self.instantiate_const_expr(&bin.rhs, erased_idx, tmp_tv_cache)?; + let lhs = self.instantiate_const_expr( + &bin.lhs, + erased_idx, + tmp_tv_cache, + not_found_is_qvar, + )?; + let rhs = self.instantiate_const_expr( + &bin.rhs, + erased_idx, + tmp_tv_cache, + not_found_is_qvar, + )?; Ok(TyParam::bin(op, lhs, rhs)) } ast::ConstExpr::UnaryOp(unary) => { @@ -821,11 +893,21 @@ impl Context { &format!("instantiating const expression {unary}") ) }; - let val = self.instantiate_const_expr(&unary.expr, erased_idx, tmp_tv_cache)?; + let val = self.instantiate_const_expr( + &unary.expr, + erased_idx, + tmp_tv_cache, + not_found_is_qvar, + )?; Ok(TyParam::unary(op, val)) } ast::ConstExpr::TypeAsc(tasc) => { - let tp = self.instantiate_const_expr(&tasc.expr, 
erased_idx, tmp_tv_cache)?; + let tp = self.instantiate_const_expr( + &tasc.expr, + erased_idx, + tmp_tv_cache, + not_found_is_qvar, + )?; let spec_t = self.instantiate_typespec( &tasc.t_spec.t_spec, None, @@ -859,8 +941,9 @@ impl Context { expr: &ast::ConstExpr, erased_idx: Option<(&Context, usize)>, tmp_tv_cache: &mut TyVarCache, + not_found_is_qvar: bool, ) -> TyCheckResult { - let tp = self.instantiate_const_expr(expr, erased_idx, tmp_tv_cache)?; + let tp = self.instantiate_const_expr(expr, erased_idx, tmp_tv_cache, not_found_is_qvar)?; self.instantiate_tp_as_type(tp, expr) } @@ -999,7 +1082,8 @@ impl Context { mode, not_found_is_qvar, )?; - let mut len = self.instantiate_const_expr(&arr.len, None, tmp_tv_cache)?; + let mut len = + self.instantiate_const_expr(&arr.len, None, tmp_tv_cache, not_found_is_qvar)?; if let TyParam::Erased(t) = &mut len { *t.as_mut() = Type::Nat; } @@ -1013,7 +1097,8 @@ impl Context { mode, not_found_is_qvar, )?; - let mut len = self.instantiate_const_expr(&set.len, None, tmp_tv_cache)?; + let mut len = + self.instantiate_const_expr(&set.len, None, tmp_tv_cache, not_found_is_qvar)?; if let TyParam::Erased(t) = &mut len { *t.as_mut() = Type::Nat; } @@ -1074,7 +1159,12 @@ impl Context { TypeSpec::Enum(set) => { let mut new_set = set! 
{}; for arg in set.pos_args() { - new_set.insert(self.instantiate_const_expr(&arg.expr, None, tmp_tv_cache)?); + new_set.insert(self.instantiate_const_expr( + &arg.expr, + None, + tmp_tv_cache, + not_found_is_qvar, + )?); } let ty = new_set.iter().fold(Type::Never, |t, tp| { self.union(&t, &self.get_tp_t(tp).unwrap()) @@ -1089,9 +1179,9 @@ impl Context { TokenKind::Open => IntervalOp::Open, _ => assume_unreachable!(), }; - let l = self.instantiate_const_expr(lhs, None, tmp_tv_cache)?; + let l = self.instantiate_const_expr(lhs, None, tmp_tv_cache, not_found_is_qvar)?; let l = self.eval_tp(l)?; - let r = self.instantiate_const_expr(rhs, None, tmp_tv_cache)?; + let r = self.instantiate_const_expr(rhs, None, tmp_tv_cache, not_found_is_qvar)?; let r = self.eval_tp(r)?; if let Some(Greater) = self.try_cmp(&l, &r) { panic!("{l}..{r} is not a valid interval type (should be lhs <= rhs)") diff --git a/crates/erg_compiler/context/unify.rs b/crates/erg_compiler/context/unify.rs index a757224f..5a39ce02 100644 --- a/crates/erg_compiler/context/unify.rs +++ b/crates/erg_compiler/context/unify.rs @@ -968,13 +968,17 @@ impl Context { (_, Type::RefMut { before, .. 
}) => self.sub_unify(maybe_sub, before, loc, param_name), (_, Type::Proj { lhs, rhs }) => { if let Ok(evaled) = self.eval_proj(*lhs.clone(), rhs.clone(), self.level, loc) { - self.sub_unify(maybe_sub, &evaled, loc, param_name)?; + if maybe_sup != &evaled { + self.sub_unify(maybe_sub, &evaled, loc, param_name)?; + } } Ok(()) } (Type::Proj { lhs, rhs }, _) => { if let Ok(evaled) = self.eval_proj(*lhs.clone(), rhs.clone(), self.level, loc) { - self.sub_unify(&evaled, maybe_sup, loc, param_name)?; + if maybe_sub != &evaled { + self.sub_unify(&evaled, maybe_sup, loc, param_name)?; + } } Ok(()) } diff --git a/crates/erg_compiler/ty/mod.rs b/crates/erg_compiler/ty/mod.rs index cc5574e1..af391edb 100644 --- a/crates/erg_compiler/ty/mod.rs +++ b/crates/erg_compiler/ty/mod.rs @@ -412,6 +412,12 @@ impl SubrType { /// essential_qnames(|T, U| (T, T) -> U) == {T} /// ``` pub fn essential_qnames(&self) -> Set { + let structural_qname = self.non_default_params.iter().find_map(|pt| { + pt.typ() + .get_super() + .map_or(false, |t| t.is_structural()) + .then(|| pt.typ().unbound_name().unwrap()) + }); let qnames_sets = self .non_default_params .iter() @@ -419,7 +425,7 @@ impl SubrType { .chain(self.var_params.iter().map(|pt| pt.typ().qnames())) .chain(self.default_params.iter().map(|pt| pt.typ().qnames())) .chain([self.return_t.qnames()]); - Set::multi_intersection(qnames_sets) + Set::multi_intersection(qnames_sets).extended(structural_qname) } pub fn has_qvar(&self) -> bool { @@ -746,6 +752,9 @@ pub enum Type { impl PartialEq for Type { fn eq(&self, other: &Self) -> bool { + if ref_addr_eq!(self, other) { + return true; + } match (self, other) { (Self::Obj, Self::Obj) | (Self::Int, Self::Int) @@ -1589,8 +1598,13 @@ impl Type { } } - pub const fn is_structural(&self) -> bool { - matches!(self, Self::Structural(_)) + pub fn is_structural(&self) -> bool { + match self { + Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_structural(), + Self::Structural(_) => true, + 
Self::Refinement(refine) => refine.t.is_structural(), + _ => false, + } } pub fn as_free(&self) -> Option<&FreeTyVar> { @@ -1607,7 +1621,12 @@ impl Type { ref_addr_eq!(fv.forced_as_ref(), target.forced_as_ref()) || fv .get_subsup() - .map(|(sub, sup)| sub.contains_tvar(target) || sup.contains_tvar(target)) + .map(|(sub, sup)| { + fv.forced_undoable_link(&Type::Never); + let res = sub.contains_tvar(target) || sup.contains_tvar(target); + fv.undo(); + res + }) .unwrap_or(false) } Self::Record(rec) => rec.iter().any(|(_, t)| t.contains_tvar(target)), @@ -1642,7 +1661,10 @@ impl Type { Self::FreeVar(fv) if fv.is_linked() => fv.crack().contains(target), Self::FreeVar(fv) => { fv.get_subsup().map_or(false, |(sub, sup)| { - sub.contains(target) || sup.contains(target) + fv.forced_undoable_link(&Type::Never); + let res = sub.contains(target) || sup.contains(target); + fv.undo(); + res }) || fv.get_type().map_or(false, |t| t.contains(target)) } Self::Record(rec) => rec.iter().any(|(_, t)| t.contains(target)), @@ -1866,6 +1888,22 @@ impl Type { } } + pub fn get_super(&self) -> Option { + match self { + Self::FreeVar(fv) if fv.is_linked() => fv.crack().get_super(), + Self::FreeVar(fv) if fv.is_unbound() => fv.get_super(), + _ => None, + } + } + + pub fn get_sub(&self) -> Option { + match self { + Self::FreeVar(fv) if fv.is_linked() => fv.crack().get_sub(), + Self::FreeVar(fv) if fv.is_unbound() => fv.get_sub(), + _ => None, + } + } + pub const fn is_free_var(&self) -> bool { matches!(self, Self::FreeVar(_)) } @@ -2017,7 +2055,7 @@ impl Type { Self::FreeVar(fv) if fv.is_generalized() => true, Self::FreeVar(fv) => { if let Some((sub, sup)) = fv.get_subsup() { - fv.undoable_link(&Type::Obj); + fv.undoable_link(&Type::Never); let res_sub = sub.has_qvar(); let res_sup = sup.has_qvar(); fv.undo();