diff --git a/crates/els/server.rs b/crates/els/server.rs index 948679e4..6faed7f6 100644 --- a/crates/els/server.rs +++ b/crates/els/server.rs @@ -478,31 +478,24 @@ impl Server { } pub(crate) fn get_index(&self) -> &SharedModuleIndex { - self.modules - .values() - .next() - .unwrap() - .context - .index() - .unwrap() + self.modules.values().next().unwrap().context.index() } pub(crate) fn get_shared(&self) -> Option<&SharedCompilerResource> { self.modules .values() .next() - .and_then(|module| module.context.shared()) + .map(|module| module.context.shared()) } pub(crate) fn clear_cache(&mut self, uri: &Url) { self.artifacts.remove(uri); if let Some(module) = self.modules.remove(uri) { - if let Some(shared) = module.context.shared() { - let path = util::uri_to_path(uri); - shared.mod_cache.remove(&path); - shared.index.remove_path(&path); - shared.graph.initialize(); - } + let shared = module.context.shared(); + let path = util::uri_to_path(uri); + shared.mod_cache.remove(&path); + shared.index.remove_path(&path); + shared.graph.initialize(); } } } diff --git a/crates/erg_common/error.rs b/crates/erg_common/error.rs index f3c114fe..4c413c0a 100644 --- a/crates/erg_common/error.rs +++ b/crates/erg_common/error.rs @@ -335,6 +335,12 @@ impl PartialOrd for Location { } } +impl Locational for Location { + fn loc(&self) -> Self { + *self + } +} + impl Location { pub fn concat(l: &L, r: &R) -> Self { let l_loc = l.loc(); diff --git a/crates/erg_compiler/codegen.rs b/crates/erg_compiler/codegen.rs index 5de451ba..7dc67462 100644 --- a/crates/erg_compiler/codegen.rs +++ b/crates/erg_compiler/codegen.rs @@ -665,6 +665,15 @@ impl PyCodeGenerator { "int__" | "nat__" | "str__" | "float__" => { self.load_convertors(); } + // NoneType is not defined in the global scope, use `type(None)` instead + "NoneType" => { + self.emit_push_null(); + self.emit_load_name_instr(Identifier::public("type")); + self.emit_load_const(ValueObj::None); + self.emit_precall_and_call(1); + self.stack_dec(); + return; + } _ => {} } let name = self diff --git a/crates/erg_compiler/context/compare.rs b/crates/erg_compiler/context/compare.rs index 986b5b38..8b10b173 100644 --- a/crates/erg_compiler/context/compare.rs +++ b/crates/erg_compiler/context/compare.rs @@ -19,7 +19,7 @@ use TyParamOrdering::*; use Type::*; use crate::context::cache::{SubtypePair, GLOBAL_TYPE_CACHE}; -use crate::context::{Context, TyVarCache, Variance}; +use crate::context::{Context, Variance}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum Credibility { @@ -144,6 +144,7 @@ impl Context { /// e.g. /// Named :> Module /// => Module.super_types == [Named] + /// /// Seq(T) :> Range(T) /// => Range(T).super_types == [Eq, Mutate, Seq('T), Output('T)] pub(crate) fn subtype_of(&self, lhs: &Type, rhs: &Type, allow_cast: bool) -> bool { @@ -218,7 +219,6 @@ impl Context { (Absolutely, true) } (FreeVar(l), FreeVar(r)) => { - log!(err "{l}/{r}/{}", l.structural_eq(r)); if l.structural_eq(r) { (Absolutely, true) } else { @@ -376,6 +376,9 @@ impl Context { match (lhs, rhs) { // Proc :> Func if params are compatible (Subr(ls), Subr(rs)) if ls.kind == rs.kind || ls.kind.is_proc() => { + if !allow_cast && ls.kind != rs.kind { + return false; + } let kw_check = || { for lpt in ls.default_params.iter() { if let Some(rpt) = rs @@ -497,7 +500,7 @@ impl Context { } true } - (Type, Record(rec)) => { + (Type, Record(rec)) if allow_cast => { for (_, t) in rec.iter() { if !self.supertype_of(&Type, t, allow_cast) { return false; @@ -574,11 +577,11 @@ impl Context { } r_preds_clone.is_empty() } - (Nat, re @ Refinement(_)) => { + (Nat, re @ Refinement(_)) if allow_cast => { let nat = Type::Refinement(Nat.into_refinement()); self.structural_supertype_of(&nat, re, allow_cast) } - (re @ Refinement(_), Nat) => { + (re @ Refinement(_), Nat) if allow_cast => { let nat = Type::Refinement(Nat.into_refinement()); self.structural_supertype_of(re, &nat, allow_cast) } @@ -588,7 +591,7 @@ impl Context { // => Eq(Int) :> Eq({1, 2}) :> {1, 2} // => true // Bool :> {1} == true - (l, Refinement(r)) => { + (l, Refinement(r)) if allow_cast => { if self.supertype_of(l, &r.t, allow_cast) { return true; } @@ -600,7 +603,7 @@ impl Context { self.structural_supertype_of(&l, rhs, allow_cast) } // ({I: Int | True} :> Int) == true, ({N: Nat | ...} :> Int) == false, ({I: Int | I >= 0} :> Int) == false - (Refinement(l), r) => { + (Refinement(l), r) if allow_cast => { if l.preds .iter() .any(|p| p.mentions(&l.var) && p.can_be_false()) @@ -609,28 +612,19 @@ impl Context { } self.supertype_of(&l.t, r, allow_cast) } - (Quantified(l), Quantified(r)) => self.structural_subtype_of(l, r, allow_cast), - (Quantified(quant), r) => { - if quant.has_uninited_qvars() { - let mut tmp_tv_cache = TyVarCache::new(self.level, self); - let inst = self - .instantiate_t_inner(*quant.clone(), &mut tmp_tv_cache, &()) - .unwrap(); - self.supertype_of(&inst, r, allow_cast) - } else { - self.supertype_of(quant, r, allow_cast) - } + (Quantified(_), Quantified(_)) => { + let l = self.instantiate_dummy(lhs.clone()); + let r = self.instantiate_dummy(rhs.clone()); + self.sub_unify(&r, &l, &(), None).is_ok() } - (l, Quantified(quant)) => { - if quant.has_uninited_qvars() { - let mut tmp_tv_cache = TyVarCache::new(self.level, self); - let inst = self - .instantiate_t_inner(*quant.clone(), &mut tmp_tv_cache, &()) - .unwrap(); - self.supertype_of(l, &inst, allow_cast) - } else { - self.supertype_of(l, quant, allow_cast) - } + // (|T: Type| T -> T) !<: Obj -> Never + (Quantified(_), r) if allow_cast => { + let inst = self.instantiate_dummy(lhs.clone()); + self.sub_unify(r, &inst, &(), None).is_ok() + } + (l, Quantified(_)) if allow_cast => { + let inst = self.instantiate_dummy(rhs.clone()); + self.sub_unify(&inst, l, &(), None).is_ok() } // Int or Str :> Str or Int == (Int :> Str && Str :> Int) || (Int :> Int && Str :> Str) == true (Or(l_1, l_2), Or(r_1, r_2)) => { @@ -642,11 +636,11 @@ impl Context { (Not(l), Not(r)) => self.subtype_of(l, r, allow_cast), // (Int or Str) :> Nat == Int :> Nat || Str :> Nat == true // (Num or Show) :> Show == Num :> Show || Show :> Num == true - (Or(l_or, r_or), rhs) => { + (Or(l_or, r_or), rhs) if allow_cast => { self.supertype_of(l_or, rhs, allow_cast) || self.supertype_of(r_or, rhs, allow_cast) } // Int :> (Nat or Str) == Int :> Nat && Int :> Str == false - (lhs, Or(l_or, r_or)) => { + (lhs, Or(l_or, r_or)) if allow_cast => { self.supertype_of(lhs, l_or, allow_cast) && self.supertype_of(lhs, r_or, allow_cast) } (And(l_1, l_2), And(r_1, r_2)) => { @@ -655,12 +649,12 @@ impl Context { && self.supertype_of(l_2, r_1, allow_cast)) } // (Num and Show) :> Show == false - (And(l_and, r_and), rhs) => { + (And(l_and, r_and), rhs) if allow_cast => { self.supertype_of(l_and, rhs, allow_cast) && self.supertype_of(r_and, rhs, allow_cast) } // Show :> (Num and Show) == true - (lhs, And(l_and, r_and)) => { + (lhs, And(l_and, r_and)) if allow_cast => { self.supertype_of(lhs, l_and, allow_cast) || self.supertype_of(lhs, r_and, allow_cast) } @@ -668,8 +662,8 @@ impl Context { (Ref(l), Ref(r)) => self.supertype_of(l, r, allow_cast), // TはすべてのRef(T)のメソッドを持つので、Ref(T)のサブタイプ // REVIEW: RefMut is invariant, maybe - (Ref(l), r) => self.supertype_of(l, r, allow_cast), - (RefMut { before: l, .. }, r) => self.supertype_of(l, r, allow_cast), + (Ref(l), r) if allow_cast => self.supertype_of(l, r, allow_cast), + (RefMut { before: l, .. }, r) if allow_cast => self.supertype_of(l, r, allow_cast), // `Eq(Set(T, N)) :> Set(T, N)` will be false, such cases are judged by nominal_supertype_of ( Poly { diff --git a/crates/erg_compiler/context/eval.rs b/crates/erg_compiler/context/eval.rs index fc071ed6..393cd3b5 100644 --- a/crates/erg_compiler/context/eval.rs +++ b/crates/erg_compiler/context/eval.rs @@ -948,6 +948,7 @@ impl Context { } } + /// lhs: mainly class pub(crate) fn eval_proj( &self, lhs: Type, diff --git a/crates/erg_compiler/context/generalize.rs b/crates/erg_compiler/context/generalize.rs index 3034350a..05eb8e35 100644 --- a/crates/erg_compiler/context/generalize.rs +++ b/crates/erg_compiler/context/generalize.rs @@ -872,10 +872,14 @@ impl Context { fn resolve_params_t(&self, params: &mut hir::Params, qnames: &Set) -> TyCheckResult<()> { for param in params.non_defaults.iter_mut() { + // generalization should work properly for the subroutine type, but may not work for the parameters' own types + // HACK: so generalize them manually + param.vi.t.generalize(); param.vi.t = self.deref_tyvar(mem::take(&mut param.vi.t), Contravariant, qnames, param)?; } if let Some(var_params) = &mut params.var_params { + var_params.vi.t.generalize(); var_params.vi.t = self.deref_tyvar( mem::take(&mut var_params.vi.t), Contravariant, @@ -884,6 +888,7 @@ impl Context { )?; } for param in params.defaults.iter_mut() { + param.sig.vi.t.generalize(); param.sig.vi.t = self.deref_tyvar(mem::take(&mut param.sig.vi.t), Contravariant, qnames, param)?; self.resolve_expr_t(&mut param.default_val)?; diff --git a/crates/erg_compiler/context/initialize/mod.rs b/crates/erg_compiler/context/initialize/mod.rs index 1d417c4f..8a98cf04 100644 --- a/crates/erg_compiler/context/initialize/mod.rs +++ b/crates/erg_compiler/context/initialize/mod.rs @@ -29,7 +29,7 @@ use crate::context::instantiate::ConstTemplate; use crate::context::{ ClassDefType, Context, ContextKind, MethodInfo, ModuleContext, ParamSpec, TraitImpl, }; -use crate::module::{SharedCompilerResource, SharedModuleCache}; +use crate::module::SharedCompilerResource; use crate::ty::free::Constraint; use crate::ty::value::ValueObj; use crate::ty::Type; @@ -703,10 +703,10 @@ impl Context { self.consts .insert(name.clone(), ValueObj::builtin_t(t.clone())); for impl_trait in ctx.super_traits.iter() { - if let Some(impls) = self.trait_impls.get_mut(&impl_trait.qual_name()) { + if let Some(impls) = self.trait_impls().get_mut(&impl_trait.qual_name()) { impls.insert(TraitImpl::new(t.clone(), impl_trait.clone())); } else { - self.trait_impls.insert( + self.trait_impls().register( impl_trait.qual_name(), set![TraitImpl::new(t.clone(), impl_trait.clone())], ); @@ -773,10 +773,10 @@ impl Context { self.consts .insert(name.clone(), ValueObj::builtin_t(t.clone())); for impl_trait in ctx.super_traits.iter() { - if let Some(impls) = self.trait_impls.get_mut(&impl_trait.qual_name()) { + if let Some(impls) = self.trait_impls().get_mut(&impl_trait.qual_name()) { impls.insert(TraitImpl::new(t.clone(), impl_trait.clone())); } else { - self.trait_impls.insert( + self.trait_impls().register( impl_trait.qual_name(), set![TraitImpl::new(t.clone(), impl_trait.clone())], ); @@ -837,11 +837,11 @@ impl Context { } } if let ContextKind::GluePatch(tr_inst) = &ctx.kind { - if let Some(impls) = self.trait_impls.get_mut(&tr_inst.sup_trait.qual_name()) { + if let Some(impls) = self.trait_impls().get_mut(&tr_inst.sup_trait.qual_name()) { impls.insert(tr_inst.clone()); } else { - self.trait_impls - .insert(tr_inst.sup_trait.qual_name(), set![tr_inst.clone()]); + self.trait_impls() + .register(tr_inst.sup_trait.qual_name(), set![tr_inst.clone()]); } } self.patches.insert(name, ctx); @@ -896,8 +896,8 @@ impl Context { self.register_builtin_py_impl(ELLIPSIS, Ellipsis, Const, Private, Some(ELLIPSIS)); } - pub(crate) fn init_builtins(cfg: ErgConfig, mod_cache: &SharedModuleCache) { - let mut ctx = Context::builtin_module("", cfg, 100); + pub(crate) fn init_builtins(cfg: ErgConfig, shared: SharedCompilerResource) { + let mut ctx = Context::builtin_module("", cfg, shared.clone(), 100); ctx.init_builtin_consts(); ctx.init_builtin_funcs(); ctx.init_builtin_const_funcs(); @@ -907,7 +907,9 @@ impl Context { ctx.init_builtin_classes(); ctx.init_builtin_patches(); let module = ModuleContext::new(ctx, dict! {}); - mod_cache.register(PathBuf::from(""), None, module); + shared + .mod_cache + .register(PathBuf::from(""), None, module); } pub fn new_module>( diff --git a/crates/erg_compiler/context/inquire.rs b/crates/erg_compiler/context/inquire.rs index 15607477..8a314cc7 100644 --- a/crates/erg_compiler/context/inquire.rs +++ b/crates/erg_compiler/context/inquire.rs @@ -50,8 +50,8 @@ pub enum SubstituteResult { impl Context { pub(crate) fn get_ctx_from_path(&self, path: &Path) -> Option<&Context> { self.mod_cache() - .and_then(|cache| cache.ref_ctx(path)) - .or_else(|| self.py_mod_cache().and_then(|cache| cache.ref_ctx(path))) + .ref_ctx(path) + .or_else(|| self.py_mod_cache().ref_ctx(path)) .map(|mod_ctx| &mod_ctx.context) } @@ -1853,15 +1853,16 @@ impl Context { .map(|(_, ctx)| ctx.super_traits.clone().into_iter()) } + /// include `typ` itself. /// if `typ` is a refinement type, include the base type (refine.t) pub(crate) fn _get_super_classes(&self, typ: &Type) -> Option> { - self.get_nominal_type_ctx(typ).map(|(_, ctx)| { + self.get_nominal_type_ctx(typ).map(|(t, ctx)| { let super_classes = ctx.super_classes.clone(); let derefined = typ.derefine(); if typ != &derefined { - vec![derefined].into_iter().chain(super_classes) + vec![t.clone(), derefined].into_iter().chain(super_classes) } else { - vec![].into_iter().chain(super_classes) + vec![t.clone()].into_iter().chain(super_classes) } }) } @@ -2054,8 +2055,8 @@ impl Context { None } - pub(crate) fn get_trait_impls(&self, t: &Type) -> Set { - match t { + pub(crate) fn get_trait_impls(&self, trait_: &Type) -> Set { + match trait_ { // And(Add, Sub) == intersection({Int <: Add(Int), Bool <: Add(Bool) ...}, {Int <: Sub(Int), ...}) // == {Int <: Add(Int) and Sub(Int), ...} Type::And(l, r) => { @@ -2079,18 +2080,18 @@ impl Context { // FIXME: l_impls.union(&r_impls) } - _ => self.get_simple_trait_impls(t), + _ => self.get_simple_trait_impls(trait_), } } - pub(crate) fn get_simple_trait_impls(&self, t: &Type) -> Set { - let current = if let Some(impls) = self.trait_impls.get(&t.qual_name()) { + pub(crate) fn get_simple_trait_impls(&self, trait_: &Type) -> Set { + let current = if let Some(impls) = self.trait_impls().get(&trait_.qual_name()) { impls.clone() } else { set! {} }; if let Some(outer) = self.get_outer().or_else(|| self.get_builtins()) { - current.union(&outer.get_simple_trait_impls(t)) + current.union(&outer.get_simple_trait_impls(trait_)) } else { current } @@ -2401,21 +2402,17 @@ impl Context { } fn get_proj_candidates(&self, lhs: &Type, rhs: &Str) -> Set { - let allow_cast = true; #[allow(clippy::single_match)] match lhs { Type::FreeVar(fv) => { if let Some(sup) = fv.get_super() { - let insts = self.get_trait_impls(&sup); - let candidates = insts.into_iter().filter_map(move |inst| { - if self.supertype_of(&inst.sup_trait, &sup, allow_cast) { - self.eval_t_params(proj(inst.sub_type, rhs), self.level, &()) - .ok() - } else { - None - } - }); - return candidates.collect(); + if self.is_trait(&sup) { + return self.get_trait_proj_candidates(&sup, rhs); + } else { + return self + .eval_proj(sup, rhs.clone(), self.level, &()) + .map_or(set! {}, |t| set! {t}); + } } } _ => {} @@ -2423,6 +2420,20 @@ impl Context { set! {} } + fn get_trait_proj_candidates(&self, trait_: &Type, rhs: &Str) -> Set { + let allow_cast = true; + let impls = self.get_trait_impls(trait_); + let candidates = impls.into_iter().filter_map(move |inst| { + if self.supertype_of(&inst.sup_trait, trait_, allow_cast) { + self.eval_t_params(proj(inst.sub_type, rhs), self.level, &()) + .ok() + } else { + None + } + }); + candidates.collect() + } + pub(crate) fn is_class(&self, typ: &Type) -> bool { match typ { Type::And(_l, _r) => false, diff --git a/crates/erg_compiler/context/instantiate.rs b/crates/erg_compiler/context/instantiate.rs index a8a776ce..8c7c147a 100644 --- a/crates/erg_compiler/context/instantiate.rs +++ b/crates/erg_compiler/context/instantiate.rs @@ -1497,4 +1497,27 @@ impl Context { other => Ok(other), } } + + pub(crate) fn instantiate_dummy(&self, quantified: Type) -> Type { + match quantified { + FreeVar(fv) if fv.is_linked() => self.instantiate_dummy(fv.crack().clone()), + Quantified(quant) => { + let mut tmp_tv_cache = TyVarCache::new(self.level, self); + let ty = self + .instantiate_t_inner(*quant, &mut tmp_tv_cache, &()) + .unwrap(); + if cfg!(feature = "debug") && ty.has_qvar() { + panic!("{ty} has qvar") + } + ty + } + Refinement(refine) if refine.t.is_quantified_subr() => { + let quant = enum_unwrap!(*refine.t, Type::Quantified); + let mut tmp_tv_cache = TyVarCache::new(self.level, self); + self.instantiate_t_inner(*quant, &mut tmp_tv_cache, &()) + .unwrap() + } + other => unreachable!("{other}"), + } + } } diff --git a/crates/erg_compiler/context/mod.rs b/crates/erg_compiler/context/mod.rs index 499549bc..54fcbbc9 100644 --- a/crates/erg_compiler/context/mod.rs +++ b/crates/erg_compiler/context/mod.rs @@ -24,11 +24,10 @@ use erg_common::config::Input; use erg_common::dict::Dict; use erg_common::error::Location; use erg_common::impl_display_from_debug; -use erg_common::set::Set; use erg_common::traits::{Locational, Stream}; use erg_common::vis::Visibility; use erg_common::Str; -use erg_common::{fn_name, get_hash, log}; +use erg_common::{fmt_option, fn_name, get_hash, log}; use ast::{DefId, DefKind, VarName}; use erg_parser::ast; @@ -217,6 +216,26 @@ impl From for ContextKind { } } +impl fmt::Display for ContextKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Func => write!(f, "Func"), + Self::Proc => write!(f, "Proc"), + Self::Class => write!(f, "Class"), + Self::MethodDefs(trait_) => write!(f, "MethodDefs({})", fmt_option!(trait_)), + Self::PatchMethodDefs(type_) => write!(f, "PatchMethodDefs({type_})"), + Self::Trait => write!(f, "Trait"), + Self::StructuralTrait => write!(f, "StructuralTrait"), + Self::Patch(type_) => write!(f, "Patch({type_})"), + Self::StructuralPatch(type_) => write!(f, "StructuralPatch({type_})"), + Self::GluePatch(type_) => write!(f, "GluePatch({type_})"), + Self::Module => write!(f, "Module"), + Self::Instant => write!(f, "Instant"), + Self::Dummy => write!(f, "Dummy"), + } + } +} + impl ContextKind { pub const fn is_method_def(&self) -> bool { matches!(self, Self::MethodDefs(_)) @@ -314,10 +333,6 @@ pub struct Context { /// K: メソッド名, V: それを実装するパッチたち /// 提供メソッドはスコープごとに実装を切り替えることができる pub(crate) method_impl_patches: Dict>, - /// K: name of a trait, V: (type, monomorphised trait that the type implements) - /// K: トレイトの名前, V: (型, その型が実装する単相化トレイト) - /// e.g. { "Named": [(Type, Named), (Func, Named), ...], "Add": [(Nat, Add(Nat)), (Int, Add(Int)), ...], ... } - pub(crate) trait_impls: Dict>, /// stores declared names (not initialized) pub(crate) decls: Dict, /// for error reporting @@ -510,7 +525,6 @@ impl Context { method_to_traits: Dict::default(), method_to_classes: Dict::default(), method_impl_patches: Dict::default(), - trait_impls: Dict::default(), params: params_, decls: Dict::default(), future_defined_locals: Dict::default(), @@ -795,8 +809,13 @@ impl Context { } #[inline] - pub fn builtin_module>(name: S, cfg: ErgConfig, capacity: usize) -> Self { - Self::module(name.into(), cfg, None, capacity) + pub fn builtin_module>( + name: S, + cfg: ErgConfig, + shared: SharedCompilerResource, + capacity: usize, + ) -> Self { + Self::module(name.into(), cfg, Some(shared), capacity) } #[inline] @@ -979,20 +998,24 @@ impl Context { .collect() } - pub(crate) fn mod_cache(&self) -> Option<&SharedModuleCache> { - self.shared.as_ref().map(|shared| &shared.mod_cache) + pub(crate) fn mod_cache(&self) -> &SharedModuleCache { + &self.shared().mod_cache } - pub(crate) fn py_mod_cache(&self) -> Option<&SharedModuleCache> { - self.shared.as_ref().map(|shared| &shared.py_mod_cache) + pub(crate) fn py_mod_cache(&self) -> &SharedModuleCache { + &self.shared().py_mod_cache } - pub fn index(&self) -> Option<&crate::module::SharedModuleIndex> { - self.shared.as_ref().map(|shared| &shared.index) + pub fn index(&self) -> &crate::module::SharedModuleIndex { + &self.shared().index } - pub fn shared(&self) -> Option<&SharedCompilerResource> { - self.shared.as_ref() + pub fn trait_impls(&self) -> &crate::module::SharedTraitImpls { + &self.shared().trait_impls + } + + pub fn shared(&self) -> &SharedCompilerResource { + self.shared.as_ref().unwrap() } } diff --git a/crates/erg_compiler/context/register.rs b/crates/erg_compiler/context/register.rs index 09ac451f..aeff83c2 100644 --- a/crates/erg_compiler/context/register.rs +++ b/crates/erg_compiler/context/register.rs @@ -165,9 +165,7 @@ impl Context { py_name, self.absolutize(ident.name.loc()), ); - if let Some(shared) = self.shared() { - shared.index.register(&vi); - } + self.index().register(&vi); self.future_defined_locals.insert(ident.name.clone(), vi); Ok(()) } @@ -213,9 +211,7 @@ impl Context { py_name, self.absolutize(sig.ident.name.loc()), ); - if let Some(shared) = self.shared() { - shared.index.register(&vi); - } + self.index().register(&vi); if let Some(_decl) = self.decls.remove(name) { Err(TyCheckErrors::from(TyCheckError::duplicate_decl_error( self.cfg.input.clone(), @@ -403,9 +399,7 @@ impl Context { None, self.absolutize(name.loc()), ); - if let Some(shared) = self.shared() { - shared.index.register(&vi); - } + self.index().register(&vi); sig.vi = vi.clone(); self.params.push((Some(name.clone()), vi)); if errs.is_empty() { @@ -1076,9 +1070,7 @@ impl Context { None, self.absolutize(ident.name.loc()), ); - if let Some(shared) = self.shared() { - shared.index.register(&vi); - } + self.index().register(&vi); self.decls.insert(ident.name.clone(), vi); self.consts.insert(ident.name.clone(), other); Ok(()) @@ -1372,9 +1364,7 @@ impl Context { None, self.absolutize(name.loc()), ); - if let Some(shared) = self.shared() { - shared.index.register(&vi); - } + self.index().register(&vi); self.decls.insert(name.clone(), vi); self.consts .insert(name.clone(), ValueObj::Type(TypeObj::Builtin(t))); @@ -1421,17 +1411,15 @@ impl Context { None, self.absolutize(name.loc()), ); - if let Some(shared) = self.shared() { - shared.index.register(&vi); - } + self.index().register(&vi); self.decls.insert(name.clone(), vi); self.consts .insert(name.clone(), ValueObj::Type(TypeObj::Generated(gen))); for impl_trait in ctx.super_traits.iter() { - if let Some(impls) = self.trait_impls.get_mut(&impl_trait.qual_name()) { + if let Some(impls) = self.trait_impls().get_mut(&impl_trait.qual_name()) { impls.insert(TraitImpl::new(t.clone(), impl_trait.clone())); } else { - self.trait_impls.insert( + self.trait_impls().register( impl_trait.qual_name(), set![TraitImpl::new(t.clone(), impl_trait.clone())], ); @@ -1507,10 +1495,10 @@ impl Context { self.consts .insert(name.clone(), ValueObj::Type(TypeObj::Generated(gen))); for impl_trait in ctx.super_traits.iter() { - if let Some(impls) = self.trait_impls.get_mut(&impl_trait.qual_name()) { + if let Some(impls) = self.trait_impls().get_mut(&impl_trait.qual_name()) { impls.insert(TraitImpl::new(t.clone(), impl_trait.clone())); } else { - self.trait_impls.insert( + self.trait_impls().register( impl_trait.qual_name(), set![TraitImpl::new(t.clone(), impl_trait.clone())], ); @@ -1555,8 +1543,8 @@ impl Context { fn import_erg_mod(&self, mod_name: &Literal) -> CompileResult { let ValueObj::Str(__name__) = mod_name.value.clone() else { todo!("{mod_name}") }; - let mod_cache = self.mod_cache().unwrap(); - let py_mod_cache = self.py_mod_cache().unwrap(); + let mod_cache = self.mod_cache(); + let py_mod_cache = self.py_mod_cache(); let path = match Self::resolve_real_path(&self.cfg, Path::new(&__name__[..])) { Some(path) => path, None => { @@ -1687,9 +1675,9 @@ impl Context { mod_name.loc(), self.caused_by(), self.similar_builtin_erg_mod_name(&__name__) - .or_else(|| self.mod_cache().unwrap().get_similar_name(&__name__)), + .or_else(|| self.mod_cache().get_similar_name(&__name__)), self.similar_builtin_py_mod_name(&__name__) - .or_else(|| self.py_mod_cache().unwrap().get_similar_name(&__name__)), + .or_else(|| self.py_mod_cache().get_similar_name(&__name__)), ); Err(TyCheckErrors::from(err)) } @@ -1725,7 +1713,7 @@ impl Context { fn import_py_mod(&self, mod_name: &Literal) -> CompileResult { let ValueObj::Str(__name__) = mod_name.value.clone() else { todo!("{mod_name}") }; - let py_mod_cache = self.py_mod_cache().unwrap(); + let py_mod_cache = self.py_mod_cache(); let path = self.get_path(mod_name, __name__)?; if let Some(referrer) = self.cfg.input.path() { let graph = &self.shared.as_ref().unwrap().graph; @@ -1888,8 +1876,6 @@ impl Context { } pub fn inc_ref(&self, vi: &VarInfo, name: &L) { - self.index() - .unwrap() - .inc_ref(vi, self.absolutize(name.loc())); + self.index().inc_ref(vi, self.absolutize(name.loc())); } } diff --git a/crates/erg_compiler/context/test.rs b/crates/erg_compiler/context/test.rs index bc3ee57e..d46aa5b4 100644 --- a/crates/erg_compiler/context/test.rs +++ b/crates/erg_compiler/context/test.rs @@ -1,5 +1,6 @@ //! test module for `Context` use erg_common::set; +use erg_common::traits::StructuralEq; use erg_common::Str; use crate::ty::constructors::{func1, mono, mono_q, poly, refinement}; @@ -17,7 +18,7 @@ impl Context { panic!("variable not found: {varname}"); }; println!("{varname}: {}", vi.t); - if self.same_type_of(&vi.t, ty, false) { + if vi.t.structural_eq(ty) { Ok(()) } else { println!("{varname} is not the type of {ty}"); @@ -34,13 +35,22 @@ impl Context { Type::Int, set! { Predicate::eq(var, TyParam::value(1)) }, ); - if self.supertype_of(&lhs, &rhs, false) { + if self.supertype_of(&lhs, &rhs, true) { Ok(()) } else { Err(()) } } + pub fn test_quant_subtyping(&self) -> Result<(), ()> { + let t = crate::ty::constructors::type_q("T"); + let quant = func1(t.clone(), t).quantify(); + let subr = func1(Obj, Never); + assert!(!self.subtype_of(&quant, &subr, true)); + assert!(self.subtype_of(&subr, &quant, true)); + Ok(()) + } + pub fn test_resolve_trait_inner1(&self) -> Result<(), ()> { let name = Str::ever("Add"); let params = vec![TyParam::t(Nat)]; diff --git a/crates/erg_compiler/context/unify.rs b/crates/erg_compiler/context/unify.rs index ca8c6790..fe332aae 100644 --- a/crates/erg_compiler/context/unify.rs +++ b/crates/erg_compiler/context/unify.rs @@ -728,6 +728,20 @@ impl Context { Ok(()) } (Type::Subr(lsub), Type::Subr(rsub)) => { + lsub.non_default_params + .iter() + .zip(rsub.non_default_params.iter()) + .try_for_each(|(l, r)| { + // contravariant + self.sub_unify(r.typ(), l.typ(), loc, param_name) + })?; + lsub.var_params + .iter() + .zip(rsub.var_params.iter()) + .try_for_each(|(l, r)| { + // contravariant + self.sub_unify(r.typ(), l.typ(), loc, param_name) + })?; for lpt in lsub.default_params.iter() { if let Some(rpt) = rsub .default_params @@ -737,16 +751,9 @@ impl Context { // contravariant self.sub_unify(rpt.typ(), lpt.typ(), loc, param_name)?; } else { - todo!() + unreachable!() } } - lsub.non_default_params - .iter() - .zip(rsub.non_default_params.iter()) - .try_for_each(|(l, r)| { - // contravariant - self.sub_unify(r.typ(), l.typ(), loc, param_name) - })?; // covariant self.sub_unify(&lsub.return_t, &rsub.return_t, loc, param_name)?; Ok(()) @@ -781,9 +788,7 @@ impl Context { } })?; // covariant - if !lsub.return_t.is_generalized() { - self.sub_unify(&lsub.return_t, &rsub.return_t, loc, param_name)?; - } + self.sub_unify(&lsub.return_t, &rsub.return_t, loc, param_name)?; Ok(()) } (Type::Subr(lsub), Type::Quantified(rsub)) => { @@ -815,9 +820,7 @@ impl Context { } })?; // covariant - if !rsub.return_t.is_generalized() { - self.sub_unify(&lsub.return_t, &rsub.return_t, loc, param_name)?; - } + self.sub_unify(&lsub.return_t, &rsub.return_t, loc, param_name)?; Ok(()) } ( diff --git a/crates/erg_compiler/hir.rs b/crates/erg_compiler/hir.rs index 205476a2..efe288bf 100644 --- a/crates/erg_compiler/hir.rs +++ b/crates/erg_compiler/hir.rs @@ -1592,7 +1592,7 @@ pub struct NonDefaultParamSignature { impl NestedDisplay for NonDefaultParamSignature { fn fmt_nest(&self, f: &mut std::fmt::Formatter<'_>, _level: usize) -> std::fmt::Result { - write!(f, "{}", self.raw) + write!(f, "{}(: {})", self.raw, self.vi.t) } } diff --git a/crates/erg_compiler/lint.rs b/crates/erg_compiler/lint.rs index 1db89274..c984a24b 100644 --- a/crates/erg_compiler/lint.rs +++ b/crates/erg_compiler/lint.rs @@ -162,25 +162,23 @@ impl ASTLowerer { if mode == "eval" { return; } - if let Some(shared) = self.module.context.shared() { - for (referee, value) in shared.index.iter() { - let code = referee.code(); - let name = code.as_ref().map(|s| &s[..]).unwrap_or(""); - let name_is_auto = name == "_"; // || name.starts_with(['%']); - if value.referrers.is_empty() && value.vi.vis.is_private() && !name_is_auto { - let input = referee - .module - .as_ref() - .map_or(self.input().clone(), |path| path.as_path().into()); - let warn = LowerWarning::unused_warning( - input, - line!() as usize, - referee.loc, - name, - self.module.context.caused_by(), - ); - self.warns.push(warn); - } + for (referee, value) in self.module.context.index().iter() { + let code = referee.code(); + let name = code.as_ref().map(|s| &s[..]).unwrap_or(""); + let name_is_auto = name == "_"; // || name.starts_with(['%']); + if value.referrers.is_empty() && value.vi.vis.is_private() && !name_is_auto { + let input = referee + .module + .as_ref() + .map_or(self.input().clone(), |path| path.as_path().into()); + let warn = LowerWarning::unused_warning( + input, + line!() as usize, + referee.loc, + name, + self.module.context.caused_by(), + ); + self.warns.push(warn); } } } diff --git a/crates/erg_compiler/lower.rs b/crates/erg_compiler/lower.rs index 9ce36a60..d8da44f5 100644 --- a/crates/erg_compiler/lower.rs +++ b/crates/erg_compiler/lower.rs @@ -1427,14 +1427,9 @@ impl ASTLowerer { ))); } let kind = ContextKind::MethodDefs(impl_trait.as_ref().map(|(t, _)| t.clone())); - let vis = if cfg!(feature = "py_compatible") { - Public - } else { - Private - }; self.module .context - .grow(&class.local_name(), kind, vis, None); + .grow(&class.local_name(), kind, hir_def.sig.vis(), None); for attr in methods.attrs.iter_mut() { match attr { ast::ClassAttr::Def(def) => { @@ -1636,10 +1631,15 @@ impl ASTLowerer { trait_loc: &impl Locational, ) -> LowerResult<()> { // TODO: polymorphic trait - if let Some(impls) = self.module.context.trait_impls.get_mut(&trait_.qual_name()) { + if let Some(impls) = self + .module + .context + .trait_impls() + .get_mut(&trait_.qual_name()) + { impls.insert(TraitImpl::new(class.clone(), trait_.clone())); } else { - self.module.context.trait_impls.insert( + self.module.context.trait_impls().register( trait_.qual_name(), set! {TraitImpl::new(class.clone(), trait_.clone())}, ); @@ -1746,6 +1746,7 @@ impl ASTLowerer { ) -> SingleLowerResult<()> { let allow_cast = true; if let Some((impl_trait, t_spec)) = impl_trait { + let impl_trait = impl_trait.normalize(); let mut unverified_names = self.module.context.locals.keys().collect::>(); if let Some(trait_obj) = self .module @@ -1763,8 +1764,10 @@ impl ASTLowerer { let def_t = &vi.t; // A(<: Add(R)), R -> A.Output // => A(<: Int), R -> A.Output - let replaced_decl_t = - decl_t.clone().replace(&impl_trait, class); + let replaced_decl_t = decl_t + .clone() + .replace(gen.typ(), &impl_trait) + .replace(&impl_trait, class); unverified_names.remove(name); // def_t must be subtype of decl_t if !self.module.context.supertype_of( @@ -1813,8 +1816,11 @@ impl ASTLowerer { self.module.context.get_var_kv(decl_name.inspect()) { let def_t = &vi.t; - let replaced_decl_t = - decl_vi.t.clone().replace(&impl_trait, class); + let replaced_decl_t = decl_vi + .t + .clone() + .replace(_typ, &impl_trait) + .replace(&impl_trait, class); unverified_names.remove(name); if !self.module.context.supertype_of( &replaced_decl_t, @@ -2042,10 +2048,10 @@ impl ASTLowerer { tasc.expr.loc(), self.module.context.caused_by(), switch_lang!( - "japanese" => "無効な型宣言です".to_string(), + "japanese" => "無効な型宣言です(左辺には記名型のみ使用出来ます)".to_string(), "simplified_chinese" => "无效的类型声明".to_string(), "traditional_chinese" => "無效的型宣告".to_string(), - "english" => "Invalid type declaration".to_string(), + "english" => "Invalid type declaration (currently only nominal types are allowed at LHS)".to_string(), ), None, ))); @@ -2073,14 +2079,7 @@ impl ASTLowerer { .sub_unify(&ident_vi.t, &spec_t, &ident, Some(ident.inspect()))?; } else { // if subtype ascription - let ctx = self - .module - .context - .get_singular_ctx_by_ident(&ident, &self.module.context.name)?; - // REVIEW: need to use subtype_of? - if ctx.super_traits.iter().all(|trait_| trait_ != &spec_t) - && ctx.super_classes.iter().all(|class| class != &spec_t) - { + if self.module.context.subtype_of(&ident_vi.t, &spec_t, true) { return Err(LowerErrors::from(LowerError::subtyping_error( self.cfg.input.clone(), line!() as usize, diff --git a/crates/erg_compiler/module/cache.rs b/crates/erg_compiler/module/cache.rs index 4a6b008d..f761d485 100644 --- a/crates/erg_compiler/module/cache.rs +++ b/crates/erg_compiler/module/cache.rs @@ -10,7 +10,7 @@ use erg_common::levenshtein::get_similar_name; use erg_common::shared::Shared; use erg_common::Str; -use crate::context::{Context, ModuleContext}; +use crate::context::ModuleContext; use crate::hir::HIR; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -157,10 +157,8 @@ impl fmt::Display for SharedModuleCache { } impl SharedModuleCache { - pub fn new(cfg: ErgConfig) -> Self { - let self_ = Self(Shared::new(ModuleCache::new())); - Context::init_builtins(cfg, &self_); - self_ + pub fn new() -> Self { + Self(Shared::new(ModuleCache::new())) } pub fn get(&self, path: &Q) -> Option<&ModuleEntry> diff --git a/crates/erg_compiler/module/global.rs b/crates/erg_compiler/module/global.rs index e377629b..86d73d93 100644 --- a/crates/erg_compiler/module/global.rs +++ b/crates/erg_compiler/module/global.rs @@ -1,7 +1,10 @@ use erg_common::config::ErgConfig; +use crate::context::Context; + use super::cache::SharedModuleCache; use super::graph::SharedModuleGraph; +use super::impls::SharedTraitImpls; use super::index::SharedModuleIndex; #[derive(Debug, Clone, Default)] @@ -10,16 +13,25 @@ pub struct SharedCompilerResource { pub py_mod_cache: SharedModuleCache, pub index: SharedModuleIndex, pub graph: SharedModuleGraph, + /// K: name of a trait, V: (type, monomorphised trait that the type implements) + /// K: トレイトの名前, V: (型, その型が実装する単相化トレイト) + /// e.g. { "Named": [(Type, Named), (Func, Named), ...], "Add": [(Nat, Add(Nat)), (Int, Add(Int)), ...], ... } + pub trait_impls: SharedTraitImpls, } impl SharedCompilerResource { + /// Initialize the shared compiler resource. + /// This API is normally called only once throughout the compilation phase. pub fn new(cfg: ErgConfig) -> Self { - Self { - mod_cache: SharedModuleCache::new(cfg.copy()), - py_mod_cache: SharedModuleCache::new(cfg), + let self_ = Self { + mod_cache: SharedModuleCache::new(), + py_mod_cache: SharedModuleCache::new(), index: SharedModuleIndex::new(), graph: SharedModuleGraph::new(), - } + trait_impls: SharedTraitImpls::new(), + }; + Context::init_builtins(cfg, self_.clone()); + self_ } pub fn clear_all(&self) { @@ -27,5 +39,6 @@ impl SharedCompilerResource { self.py_mod_cache.initialize(); self.index.initialize(); self.graph.initialize(); + self.trait_impls.initialize(); } } diff --git a/crates/erg_compiler/module/impls.rs b/crates/erg_compiler/module/impls.rs new file mode 100644 index 00000000..e5fb77f7 --- /dev/null +++ b/crates/erg_compiler/module/impls.rs @@ -0,0 +1,114 @@ +use std::borrow::Borrow; +use std::fmt; +use std::hash::Hash; + +use erg_common::dict::Dict; +use erg_common::set::Set; +use erg_common::shared::Shared; +use erg_common::Str; + +use crate::context::TraitImpl; + +/// Caches checked modules. +/// In addition to being queried here when re-imported, it is also used when linking +/// (Erg links all scripts defined in erg and outputs them to a single pyc file). +#[derive(Debug, Default)] +pub struct TraitImpls { + cache: Dict>, +} + +impl fmt::Display for TraitImpls { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "TraitImpls {{")?; + for (name, impls) in self.cache.iter() { + writeln!(f, "{name}: {impls}, ")?; + } + write!(f, "}}") + } +} + +impl TraitImpls { + pub fn new() -> Self { + Self { cache: Dict::new() } + } + + pub fn get(&self, path: &P) -> Option<&Set> + where + Str: Borrow

, + { + self.cache.get(path) + } + + pub fn get_mut(&mut self, path: &Q) -> Option<&mut Set> + where + Str: Borrow, + { + self.cache.get_mut(path) + } + + pub fn register(&mut self, name: Str, impls: Set) { + self.cache.insert(name, impls); + } + + pub fn remove(&mut self, path: &Q) -> Option> + where + Str: Borrow, + { + self.cache.remove(path) + } + + pub fn initialize(&mut self) { + self.cache.clear(); + } +} + +#[derive(Debug, Clone, Default)] +pub struct SharedTraitImpls(Shared); + +impl fmt::Display for SharedTraitImpls { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Shared{}", self.0) + } +} + +impl SharedTraitImpls { + pub fn new() -> Self { + Self(Shared::new(TraitImpls::new())) + } + + pub fn get(&self, path: &Q) -> Option<&Set> + where + Str: Borrow, + { + let ref_ = unsafe { self.0.as_ptr().as_ref().unwrap() }; + ref_.get(path) + } + + pub fn get_mut(&self, path: &Q) -> Option<&mut Set> + where + Str: Borrow, + { + let ref_ = unsafe { self.0.as_ptr().as_mut().unwrap() }; + ref_.get_mut(path) + } + + pub fn register(&self, name: Str, impls: Set) { + self.0.borrow_mut().register(name, impls); + } + + pub fn remove(&self, path: &Q) -> Option> + where + Str: Borrow, + { + self.0.borrow_mut().remove(path) + } + + pub fn keys(&self) -> impl Iterator { + let ref_ = unsafe { self.0.as_ptr().as_ref().unwrap() }; + ref_.cache.keys().cloned() + } + + pub fn initialize(&self) { + self.0.borrow_mut().initialize(); + } +} diff --git a/crates/erg_compiler/module/mod.rs b/crates/erg_compiler/module/mod.rs index dee5eddf..3b00c675 100644 --- a/crates/erg_compiler/module/mod.rs +++ b/crates/erg_compiler/module/mod.rs @@ -1,9 +1,11 @@ pub mod cache; pub mod global; pub mod graph; +pub mod impls; pub mod index; pub use cache::*; pub use global::*; pub use graph::*; +pub use impls::*; pub use index::*; diff --git a/crates/erg_compiler/tests/test.rs b/crates/erg_compiler/tests/test.rs index d7cc5a6d..ffafbe1d 100644 --- a/crates/erg_compiler/tests/test.rs +++ b/crates/erg_compiler/tests/test.rs @@ -61,12 +61,19 @@ fn test_infer_types() -> Result<(), ()> { } #[test] -fn test_subtyping() -> Result<(), ()> { +fn test_refinement_subtyping() -> Result<(), ()> { let context = Context::default_with_name(""); context.test_refinement_subtyping()?; Ok(()) } +#[test] +fn test_quant_subtyping() -> Result<(), ()> { + let context = Context::default_with_name(""); + context.test_quant_subtyping()?; + Ok(()) +} + #[test] fn test_instantiation_and_generalization() -> Result<(), ()> { let context = Context::default_with_name(""); diff --git a/crates/erg_compiler/ty/free.rs b/crates/erg_compiler/ty/free.rs index 1b7828f2..b0a75010 100644 --- a/crates/erg_compiler/ty/free.rs +++ b/crates/erg_compiler/ty/free.rs @@ -169,6 +169,7 @@ impl LimitedDisplay for Constraint { } impl Constraint { + /// :> Sub, <: Sup pub const fn new_sandwiched(sub: Type, sup: Type) -> Self { Self::Sandwiched { sub, sup } } diff --git a/crates/erg_compiler/ty/mod.rs b/crates/erg_compiler/ty/mod.rs index 794508cc..1b0b9275 100644 --- a/crates/erg_compiler/ty/mod.rs +++ b/crates/erg_compiler/ty/mod.rs @@ -1323,6 +1323,7 @@ impl Type { } pub fn quantify(self) -> Self { + debug_assert!(self.is_subr()); Self::Quantified(Box::new(self)) } @@ -2130,73 +2131,231 @@ impl Type { } pub fn replace(self, target: &Type, to: &Type) -> Type { - if &self == target { - return to.clone(); + let table = ReplaceTable::make(target, to); + table.replace(self) + } + + fn _replace(mut self, target: &Type, to: &Type) -> Type { + if self.structural_eq(target) { + self = to.clone(); } match self { - Self::FreeVar(fv) if fv.is_linked() => fv.crack().clone().replace(target, to), + Self::FreeVar(fv) if fv.is_linked() => fv.crack().clone()._replace(target, to), + Self::FreeVar(fv) => { + if let Some((sub, sup)) = fv.get_subsup() { + fv.forced_undoable_link(&sub); + let sub = sub._replace(target, to); + let sup = sup._replace(target, to); + fv.undo(); + fv.update_constraint(Constraint::new_sandwiched(sub, sup), true); + } else if let Some(ty) = fv.get_type() { + fv.update_constraint(Constraint::new_type_of(ty._replace(target, to)), true); + } + Self::FreeVar(fv) + } Self::Refinement(mut refine) => { - refine.t = Box::new(refine.t.replace(target, to)); + refine.t = Box::new(refine.t._replace(target, to)); Self::Refinement(refine) } Self::Record(mut rec) => { for v in rec.values_mut() { - *v = std::mem::take(v).replace(target, to); + *v = std::mem::take(v)._replace(target, to); } Self::Record(rec) } Self::Subr(mut subr) => { for nd in subr.non_default_params.iter_mut() { - *nd.typ_mut() = std::mem::take(nd.typ_mut()).replace(target, to); + *nd.typ_mut() = std::mem::take(nd.typ_mut())._replace(target, to); } if let Some(var) = subr.var_params.as_mut() { *var.as_mut().typ_mut() = - std::mem::take(var.as_mut().typ_mut()).replace(target, to); + std::mem::take(var.as_mut().typ_mut())._replace(target, to); } for d in subr.default_params.iter_mut() { - *d.typ_mut() = std::mem::take(d.typ_mut()).replace(target, to); + *d.typ_mut() = std::mem::take(d.typ_mut())._replace(target, to); } - subr.return_t = Box::new(subr.return_t.replace(target, to)); + subr.return_t = Box::new(subr.return_t._replace(target, to)); Self::Subr(subr) } Self::Callable { param_ts, return_t } => { let param_ts = param_ts .into_iter() - .map(|t| t.replace(target, to)) + .map(|t| t._replace(target, to)) .collect(); - let return_t = Box::new(return_t.replace(target, to)); + let return_t = Box::new(return_t._replace(target, to)); Self::Callable { param_ts, return_t } } - Self::Quantified(quant) => quant.replace(target, to).quantify(), + Self::Quantified(quant) => quant._replace(target, to).quantify(), Self::Poly { name, params } => { let params = params .into_iter() - .map(|tp| match tp { - TyParam::Type(t) => TyParam::t(t.replace(target, to)), - other => other, - }) + .map(|tp| tp.replace(target, to)) .collect(); Self::Poly { name, params } } - Self::Ref(t) => Self::Ref(Box::new(t.replace(target, to))), + Self::Ref(t) => Self::Ref(Box::new(t._replace(target, to))), Self::RefMut { before, after } => Self::RefMut { - before: Box::new(before.replace(target, to)), - after: after.map(|t| Box::new(t.replace(target, to))), + before: Box::new(before._replace(target, to)), + after: after.map(|t| Box::new(t._replace(target, to))), }, Self::And(l, r) => { - let l = l.replace(target, to); - let r = r.replace(target, to); + let l = l._replace(target, to); + let r = r._replace(target, to); Self::And(Box::new(l), Box::new(r)) } Self::Or(l, r) => { - let l = l.replace(target, to); - let r = r.replace(target, to); + let l = l._replace(target, to); + let r = r._replace(target, to); Self::Or(Box::new(l), Box::new(r)) } - Self::Not(ty) => Self::Not(Box::new(ty.replace(target, to))), + Self::Not(ty) => Self::Not(Box::new(ty._replace(target, to))), + Self::Proj { lhs, rhs } => lhs._replace(target, to).proj(rhs), + Self::ProjCall { + lhs, + attr_name, + args, + } => { + let args = args.into_iter().map(|tp| tp.replace(target, to)).collect(); + lhs.replace(target, to).proj_call(attr_name, args) + } other => other, } } + + /// TyParam::Value(ValueObj::Type(_)) => TyParam::Type + pub fn normalize(self) -> Self { + match self { + Self::FreeVar(fv) if fv.is_linked() => fv.crack().clone().normalize(), + Self::Poly { name, params } => { + let params = params.into_iter().map(|tp| tp.normalize()).collect(); + Self::Poly { name, params } + } + Self::Subr(mut subr) => { + for nd in subr.non_default_params.iter_mut() { + *nd.typ_mut() = std::mem::take(nd.typ_mut()).normalize(); + } + if let Some(var) = subr.var_params.as_mut() { + *var.as_mut().typ_mut() = std::mem::take(var.as_mut().typ_mut()).normalize(); + } + for d in subr.default_params.iter_mut() { + *d.typ_mut() = std::mem::take(d.typ_mut()).normalize(); + } + subr.return_t = Box::new(subr.return_t.normalize()); + Self::Subr(subr) + } + Self::Proj { lhs, rhs } => lhs.normalize().proj(rhs), + other => other, + } + } +} + +pub struct ReplaceTable<'t> { + rules: Vec<(&'t Type, &'t Type)>, +} + +impl<'t> ReplaceTable<'t> { + pub fn make(target: &'t Type, to: &'t Type) -> Self { + let mut self_ = ReplaceTable { rules: vec![] }; + self_.iterate(target, to); + self_ + } + + pub fn replace(&self, mut ty: Type) -> Type { + for (target, to) in self.rules.iter() { + log!(err "{target} /=> {to}"); + ty = ty._replace(target, to); + } + ty + } + + fn iterate(&mut self, target: &'t Type, to: &'t Type) { + match (target, to) { + ( + Type::Poly { name, params }, + Type::Poly { + name: name2, + params: params2, + }, + ) if name == name2 => { + for (t1, t2) in params.iter().zip(params2.iter()) { + self.iterate_tp(t1, t2); + } + } + (Type::Subr(lsub), Type::Subr(rsub)) => { + for (lnd, rnd) in lsub + .non_default_params + .iter() + .zip(rsub.non_default_params.iter()) + { + self.iterate(lnd.typ(), rnd.typ()); + } + for (lv, rv) in lsub.var_params.iter().zip(rsub.var_params.iter()) { + self.iterate(lv.typ(), rv.typ()); + } + for (ld, rd) in lsub.default_params.iter().zip(rsub.default_params.iter()) { + self.iterate(ld.typ(), rd.typ()); + } + self.iterate(lsub.return_t.as_ref(), rsub.return_t.as_ref()); + } + (Type::Quantified(quant), Type::Quantified(quant2)) => { + self.iterate(quant, quant2); + } + ( + Type::Proj { lhs, rhs }, + Type::Proj { + lhs: lhs2, + rhs: rhs2, + }, + ) if rhs == rhs2 => { + self.iterate(lhs, lhs2); + } + (Type::And(l, r), Type::And(l2, r2)) => { + self.iterate(l, l2); + self.iterate(r, r2); + } + (Type::Or(l, r), Type::Or(l2, r2)) => { + self.iterate(l, l2); + self.iterate(r, r2); + } + (Type::Not(t), Type::Not(t2)) => { + self.iterate(t, t2); + } + (Type::Ref(t), Type::Ref(t2)) => { + self.iterate(t, t2); + } + ( + Type::RefMut { before, after }, + Type::RefMut { + before: before2, + after: after2, + }, + ) => { + self.iterate(before, before2); + if let (Some(after), Some(after2)) = (after.as_ref(), after2.as_ref()) { + self.iterate(after, after2); + } + } + _ => {} + } + self.rules.push((target, to)); + } + + fn iterate_tp(&mut self, target: &'t TyParam, to: &'t TyParam) { + match (target, to) { + (TyParam::FreeVar(fv), to) if fv.is_linked() => self.iterate_tp(fv.unsafe_crack(), to), + (TyParam::Value(ValueObj::Type(target)), TyParam::Value(ValueObj::Type(to))) => { + self.iterate(target.typ(), to.typ()); + } + (TyParam::Type(t1), TyParam::Type(t2)) => self.iterate(t1, t2), + (TyParam::Value(ValueObj::Type(t1)), TyParam::Type(t2)) => { + self.iterate(t1.typ(), t2); + } + (TyParam::Type(t1), TyParam::Value(ValueObj::Type(t2))) => { + self.iterate(t1, t2.typ()); + } + _ => {} + } + } } /// Opcode used when Erg implements its own processor diff --git a/crates/erg_compiler/ty/typaram.rs b/crates/erg_compiler/ty/typaram.rs index d4b1f6be..afbf2041 100644 --- a/crates/erg_compiler/ty/typaram.rs +++ b/crates/erg_compiler/ty/typaram.rs @@ -719,6 +719,14 @@ impl TyParam { Self::Erased(Box::new(t)) } + pub fn proj_call(self, attr_name: Str, args: Vec) -> Type { + Type::ProjCall { + lhs: Box::new(self), + attr_name, + args, + } + } + // if self: Ratio, Succ(self) => self+ε pub fn succ(self) -> Self { Self::app("Succ", vec![self]) @@ -906,6 +914,26 @@ impl TyParam { _ => true, } } + + pub fn replace(self, target: &Type, to: &Type) -> TyParam { + match self { + TyParam::Value(ValueObj::Type(obj)) => { + TyParam::t(obj.typ().clone()._replace(target, to)) + } + TyParam::FreeVar(fv) if fv.is_linked() => fv.crack().clone().replace(target, to), + TyParam::Type(ty) => TyParam::t(ty._replace(target, to)), + self_ => self_, + } + } + + /// TyParam::Value(ValueObj::Type(_)) => TyParam::Type + pub fn normalize(self) -> TyParam { + match self { + TyParam::Value(ValueObj::Type(obj)) => TyParam::t(obj.typ().clone().normalize()), + TyParam::Type(t) => TyParam::t(t.normalize()), + other => other, + } + } } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] diff --git a/examples/impl.er b/examples/impl.er index 4d4bd11f..84acb059 100644 --- a/examples/impl.er +++ b/examples/impl.er @@ -7,7 +7,7 @@ Point|Point <: Add(Point)|. __add__ self, other: Point = Point.new(self::x + other::x, self::y + other::y) Point|Point <: Mul(Point)|. - Output = Nat + Output = Int __mul__ self, other: Point = self::x * other::x + self::y * other::y Point|Point <: Eq|. @@ -19,7 +19,7 @@ p = Point.new 1, 2 q = Point.new 3, 4 r: Point = p + q -s: Nat = p * q +s: Int = p * q assert s == 11 assert r == Point.new 4, 6 assert r.norm() == 52 diff --git a/tests/should_err/impl.er b/tests/should_err/impl.er new file mode 100644 index 00000000..f5bb099f --- /dev/null +++ b/tests/should_err/impl.er @@ -0,0 +1,5 @@ +impl = import "../should_ok/impl" + +c = impl.C.new() +print! c + 1 + diff --git a/tests/should_ok/impl.er b/tests/should_ok/impl.er new file mode 100644 index 00000000..72354811 --- /dev/null +++ b/tests/should_ok/impl.er @@ -0,0 +1,20 @@ +.C = Class() +.C|.C <: Eq|. + __eq__ self, other: .C = + _ = self + _ = other + True +.C|.C <: Add(Nat)|. + Output = Nat + __add__ self, other: Nat = + _ = self + other +.C|.C <: Add(Int)|. + Output = .C + __add__ self, other: Int = + _ = other + self + +c = .C.new() +assert c + 1 == 1 +assert c + -1 == c diff --git a/tests/test.rs b/tests/test.rs index 94a642af..602e78c0 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -196,6 +196,12 @@ fn exec_args() -> Result<(), ()> { expect_failure("tests/should_err/args.er", 16) } +/// This file compiles successfully, but causes a run-time error due to incomplete method dispatching +#[test] +fn exec_tests_impl() -> Result<(), ()> { + expect_end_with("tests/should_ok/impl.er", 1) +} + #[test] fn exec_infer_union_array() -> Result<(), ()> { expect_failure("tests/should_err/infer_union_array.er", 1)