From 968d3b5d2c9b1572d628a0540895ae1b18d29fed Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Fri, 28 Oct 2022 18:03:35 +0900 Subject: [PATCH] Fix trait implementation check --- compiler/erg_compiler/context/compare.rs | 21 +---- compiler/erg_compiler/context/eval.rs | 3 +- compiler/erg_compiler/context/inquire.rs | 15 ++-- compiler/erg_compiler/context/instantiate.rs | 7 ++ compiler/erg_compiler/context/register.rs | 3 +- compiler/erg_compiler/lower.rs | 90 ++++++++++++-------- compiler/erg_compiler/ty/mod.rs | 73 ++++++++++++++++ compiler/erg_parser/parse.rs | 2 +- 8 files changed, 147 insertions(+), 67 deletions(-) diff --git a/compiler/erg_compiler/context/compare.rs b/compiler/erg_compiler/context/compare.rs index 2e79e7b4..a2212776 100644 --- a/compiler/erg_compiler/context/compare.rs +++ b/compiler/erg_compiler/context/compare.rs @@ -175,27 +175,7 @@ impl Context { | (Float | Ratio, Ratio) | (Float, Float) => (Absolutely, true), (Type, ClassType | TraitType) => (Absolutely, true), - (Type, Record(rec)) => ( - Absolutely, - rec.iter().all(|(_, attr)| self.supertype_of(&Type, attr)), - ), (Type::Uninited, _) | (_, Type::Uninited) => panic!("used an uninited type variable"), - (Type, Subr(subr)) => ( - Absolutely, - subr.non_default_params - .iter() - .all(|pt| self.supertype_of(&Type, pt.typ())) - && subr - .default_params - .iter() - .all(|pt| self.supertype_of(&Type, pt.typ())) - && subr - .var_params - .as_ref() - .map(|va| self.supertype_of(&Type, va.typ())) - .unwrap_or(true) - && self.supertype_of(&Type, &subr.return_t), - ), ( Type::Mono(n), Subr(SubrType { @@ -461,6 +441,7 @@ impl Context { } true } + (Type, Subr(subr)) => self.supertype_of(&Type, &subr.return_t), (Type, Poly { name, params }) | (Poly { name, params }, Type) if &name[..] == "Array" || &name[..] == "Set" => { diff --git a/compiler/erg_compiler/context/eval.rs b/compiler/erg_compiler/context/eval.rs index f8a72df0..a3f4c406 100644 --- a/compiler/erg_compiler/context/eval.rs +++ b/compiler/erg_compiler/context/eval.rs @@ -259,7 +259,8 @@ impl Context { Signature::Var(_) => None, }; // TODO: set params - self.grow(__name__, ContextKind::Instant, vis, tv_cache); + let kind = ContextKind::from(def.def_kind()); + self.grow(__name__, kind, vis, tv_cache); let obj = self.eval_const_block(&def.body.block).map_err(|e| { self.pop(); e diff --git a/compiler/erg_compiler/context/inquire.rs b/compiler/erg_compiler/context/inquire.rs index 36b427ea..fe304d72 100644 --- a/compiler/erg_compiler/context/inquire.rs +++ b/compiler/erg_compiler/context/inquire.rs @@ -1827,18 +1827,19 @@ impl Context { } } - /// FIXME: if trait, returns a freevar + // TODO: poly type pub(crate) fn rec_get_self_t(&self) -> Option { if self.kind.is_method_def() || self.kind.is_type() { - // TODO: poly type - let name = self.name.split(&[':', '.']).last().unwrap(); - // let mono_t = mono(self.path(), Str::rc(name)); - if let Some((t, _)) = self.rec_get_type(name) { + // let name = self.name.split(&[':', '.']).last().unwrap(); + /*if let Some((t, _)) = self.rec_get_type(name) { + log!("{t}"); Some(t.clone()) } else { + log!("none"); None - } - } else if let Some(outer) = self.get_outer().or_else(|| self.get_builtins()) { + }*/ + Some(mono(self.name.clone())) + } else if let Some(outer) = self.get_outer() { outer.rec_get_self_t() } else { None diff --git a/compiler/erg_compiler/context/instantiate.rs b/compiler/erg_compiler/context/instantiate.rs index 2cf42ca3..4fce1a8c 100644 --- a/compiler/erg_compiler/context/instantiate.rs +++ b/compiler/erg_compiler/context/instantiate.rs @@ -412,6 +412,13 @@ impl Context { let t = self.instantiate_const_expr_as_type(&first.expr, None, tmp_tv_cache)?; Ok(ref_mut(t, None)) } + "Self" => self.rec_get_self_t().ok_or_else(|| { + TyCheckErrors::from(TyCheckError::unreachable( + self.cfg.input.clone(), + erg_common::fn_name_full!(), + line!(), + )) + }), other if simple.args.is_empty() => { if let Some(t) = tmp_tv_cache.get_tyvar(other) { return Ok(t.clone()); diff --git a/compiler/erg_compiler/context/register.rs b/compiler/erg_compiler/context/register.rs index 9d4e82c7..61041c89 100644 --- a/compiler/erg_compiler/context/register.rs +++ b/compiler/erg_compiler/context/register.rs @@ -582,7 +582,8 @@ impl Context { } } ast::Signature::Var(sig) if sig.is_const() => { - self.grow(__name__, ContextKind::Instant, sig.vis(), None); + let kind = ContextKind::from(def.def_kind()); + self.grow(__name__, kind, sig.vis(), None); let (obj, const_t) = match self.eval_const_block(&def.body.block) { Ok(obj) => (obj.clone(), v_enum(set! {obj})), Err(e) => { diff --git a/compiler/erg_compiler/lower.rs b/compiler/erg_compiler/lower.rs index 5e3c30c0..d5e43e9b 100644 --- a/compiler/erg_compiler/lower.rs +++ b/compiler/erg_compiler/lower.rs @@ -1065,6 +1065,10 @@ impl ASTLowerer { None, ), }; + // assume the class has implemented the trait, regardless of whether the implementation is correct + if let Some((trait_, trait_loc)) = &impl_trait { + self.register_trait_impl(&class, trait_, *trait_loc)?; + } if let Some(class_root) = self.ctx.get_nominal_type_ctx(&class) { if !class_root.kind.is_class() { return Err(LowerErrors::from(LowerError::method_definition_error( @@ -1114,7 +1118,6 @@ impl ASTLowerer { } if let Some((trait_, _)) = &impl_trait { self.check_override(&class, Some(trait_)); - self.register_trait_impl(&class, trait_); } else { self.check_override(&class, None); } @@ -1151,6 +1154,42 @@ impl ASTLowerer { )) } + fn register_trait_impl( + &mut self, + class: &Type, + trait_: &Type, + trait_loc: Location, + ) -> LowerResult<()> { + // TODO: polymorphic trait + if let Some(impls) = self.ctx.trait_impls.get_mut(&trait_.qual_name()) { + impls.insert(TypeRelationInstance::new(class.clone(), trait_.clone())); + } else { + self.ctx.trait_impls.insert( + trait_.qual_name(), + set! {TypeRelationInstance::new(class.clone(), trait_.clone())}, + ); + } + let trait_ctx = if let Some(trait_ctx) = self.ctx.get_nominal_type_ctx(trait_) { + trait_ctx.clone() + } else { + // TODO: maybe parameters are wrong + return Err(LowerErrors::from(LowerError::no_var_error( + self.cfg.input.clone(), + line!() as usize, + trait_loc, + self.ctx.caused_by(), + &trait_.local_name(), + None, + ))); + }; + let (_, class_ctx) = self + .ctx + .get_mut_nominal_type_ctx(class) + .unwrap_or_else(|| todo!("{class} not found")); + class_ctx.register_supertrait(trait_.clone(), &trait_ctx); + Ok(()) + } + /// HACK: Cannot be methodized this because `&self` has been taken immediately before. fn check_inheritable( cfg: &ErgConfig, @@ -1224,35 +1263,22 @@ impl ASTLowerer { class: &Type, ) -> SingleLowerResult<()> { if let Some((impl_trait, loc)) = impl_trait { - // assume the class has implemented the trait, regardless of whether the implementation is correct - let trait_ctx = if let Some(trait_ctx) = self.ctx.get_nominal_type_ctx(&impl_trait) { - trait_ctx.clone() - } else { - // TODO: maybe parameters are wrong - return Err(LowerError::no_var_error( - self.cfg.input.clone(), - line!() as usize, - loc, - self.ctx.caused_by(), - &impl_trait.local_name(), - None, - )); - }; - let (_, class_ctx) = self - .ctx - .get_mut_nominal_type_ctx(class) - .unwrap_or_else(|| todo!("{class} not found")); - class_ctx.register_supertrait(impl_trait.clone(), &trait_ctx); let mut unverified_names = self.ctx.locals.keys().collect::>(); if let Some(trait_obj) = self.ctx.rec_get_const_obj(&impl_trait.local_name()) { if let ValueObj::Type(typ) = trait_obj { match typ { TypeObj::Generated(gen) => match gen.require_or_sup().unwrap().typ() { Type::Record(attrs) => { - for (field, field_typ) in attrs.iter() { + for (field, decl_t) in attrs.iter() { if let Some((name, vi)) = self.ctx.get_local_kv(&field.symbol) { + let def_t = &vi.t; + // A(<: Add(R)), R -> A.Output + // => A(<: Int), R -> A.Output + let replaced_decl_t = + decl_t.clone().replace(&impl_trait, class); unverified_names.remove(name); - if !self.ctx.supertype_of(field_typ, &vi.t) { + // def_t must be subtype of decl_t + if !self.ctx.supertype_of(&replaced_decl_t, def_t) { self.errs.push(LowerError::trait_member_type_error( self.cfg.input.clone(), line!() as usize, @@ -1260,7 +1286,7 @@ impl ASTLowerer { self.ctx.caused_by(), name.inspect(), &impl_trait, - field_typ, + decl_t, &vi.t, None, )); @@ -1285,8 +1311,11 @@ impl ASTLowerer { for (decl_name, decl_vi) in ctx.decls.iter() { if let Some((name, vi)) = self.ctx.get_local_kv(decl_name.inspect()) { + let def_t = &vi.t; + let replaced_decl_t = + decl_vi.t.clone().replace(&impl_trait, class); unverified_names.remove(name); - if !self.ctx.supertype_of(&decl_vi.t, &vi.t) { + if !self.ctx.supertype_of(&replaced_decl_t, def_t) { self.errs.push(LowerError::trait_member_type_error( self.cfg.input.clone(), line!() as usize, @@ -1351,19 +1380,6 @@ impl ASTLowerer { Ok(()) } - fn register_trait_impl(&mut self, class: &Type, trait_: &Type) { - let trait_impls = &mut self.ctx.outer.as_mut().unwrap().trait_impls; - // TODO: polymorphic trait - if let Some(impls) = trait_impls.get_mut(&trait_.qual_name()) { - impls.insert(TypeRelationInstance::new(class.clone(), trait_.clone())); - } else { - trait_impls.insert( - trait_.qual_name(), - set! {TypeRelationInstance::new(class.clone(), trait_.clone())}, - ); - } - } - fn check_collision_and_push(&mut self, class: Type) { let methods = self.ctx.pop(); let (_, class_root) = self diff --git a/compiler/erg_compiler/ty/mod.rs b/compiler/erg_compiler/ty/mod.rs index 09df63d4..3b64ca3b 100644 --- a/compiler/erg_compiler/ty/mod.rs +++ b/compiler/erg_compiler/ty/mod.rs @@ -2167,6 +2167,79 @@ impl Type { other => other.clone(), } } + + pub fn replace(self, target: &Type, to: &Type) -> Type { + if &self == target { + return to.clone(); + } + match self { + Self::FreeVar(fv) if fv.is_linked() => fv.crack().clone().replace(target, to), + Self::Refinement(mut refine) => { + refine.t = Box::new(refine.t.replace(target, to)); + Self::Refinement(refine) + } + Self::Record(mut rec) => { + for v in rec.values_mut() { + *v = std::mem::take(v).replace(target, to); + } + Self::Record(rec) + } + Self::Subr(mut subr) => { + for nd in subr.non_default_params.iter_mut() { + *nd.typ_mut() = std::mem::take(nd.typ_mut()).replace(target, to); + } + if let Some(var) = subr.var_params.as_mut() { + *var.as_mut().typ_mut() = + std::mem::take(var.as_mut().typ_mut()).replace(target, to); + } + for d in subr.default_params.iter_mut() { + *d.typ_mut() = std::mem::take(d.typ_mut()).replace(target, to); + } + subr.return_t = Box::new(subr.return_t.replace(target, to)); + Self::Subr(subr) + } + Self::Callable { param_ts, return_t } => { + let param_ts = param_ts + .into_iter() + .map(|t| t.replace(target, to)) + .collect(); + let return_t = Box::new(return_t.replace(target, to)); + Self::Callable { param_ts, return_t } + } + Self::Quantified(quant) => quant.replace(target, to).quantify(), + Self::Poly { name, params } => { + let params = params + .into_iter() + .map(|tp| match tp { + TyParam::Type(t) => TyParam::t(t.replace(target, to)), + other => other, + }) + .collect(); + Self::Poly { name, params } + } + Self::Ref(t) => Self::Ref(Box::new(t.replace(target, to))), + Self::RefMut { before, after } => Self::RefMut { + before: Box::new(before.replace(target, to)), + after: after.map(|t| Box::new(t.replace(target, to))), + }, + Self::And(l, r) => { + let l = l.replace(target, to); + let r = r.replace(target, to); + Self::And(Box::new(l), Box::new(r)) + } + Self::Or(l, r) => { + let l = l.replace(target, to); + let r = r.replace(target, to); + Self::Or(Box::new(l), Box::new(r)) + } + Self::Not(l, r) => { + let l = l.replace(target, to); + let r = r.replace(target, to); + Self::Not(Box::new(l), Box::new(r)) + } + other => other, + } + } } /// バイトコード命令で、in-place型付けをするオブジェクト diff --git a/compiler/erg_parser/parse.rs b/compiler/erg_parser/parse.rs index ac89ca19..565b0e48 100644 --- a/compiler/erg_parser/parse.rs +++ b/compiler/erg_parser/parse.rs @@ -2832,7 +2832,7 @@ impl Parser { ) -> ParseResult { debug_call_info!(self); let sig = self - .convert_rhs_to_param(*tasc.expr, true) + .convert_rhs_to_param(Expr::TypeAsc(tasc), true) .map_err(|_| self.stack_dec())?; self.level -= 1; Ok(LambdaSignature::new(