Fix trait implementation check

This commit is contained in:
Shunsuke Shibayama 2022-10-28 18:03:35 +09:00
parent 3d35db4e3b
commit 968d3b5d2c
8 changed files with 147 additions and 67 deletions

View file

@ -175,27 +175,7 @@ impl Context {
| (Float | Ratio, Ratio) | (Float | Ratio, Ratio)
| (Float, Float) => (Absolutely, true), | (Float, Float) => (Absolutely, true),
(Type, ClassType | TraitType) => (Absolutely, true), (Type, ClassType | TraitType) => (Absolutely, true),
(Type, Record(rec)) => (
Absolutely,
rec.iter().all(|(_, attr)| self.supertype_of(&Type, attr)),
),
(Type::Uninited, _) | (_, Type::Uninited) => panic!("used an uninited type variable"), (Type::Uninited, _) | (_, Type::Uninited) => panic!("used an uninited type variable"),
(Type, Subr(subr)) => (
Absolutely,
subr.non_default_params
.iter()
.all(|pt| self.supertype_of(&Type, pt.typ()))
&& subr
.default_params
.iter()
.all(|pt| self.supertype_of(&Type, pt.typ()))
&& subr
.var_params
.as_ref()
.map(|va| self.supertype_of(&Type, va.typ()))
.unwrap_or(true)
&& self.supertype_of(&Type, &subr.return_t),
),
( (
Type::Mono(n), Type::Mono(n),
Subr(SubrType { Subr(SubrType {
@ -461,6 +441,7 @@ impl Context {
} }
true true
} }
(Type, Subr(subr)) => self.supertype_of(&Type, &subr.return_t),
(Type, Poly { name, params }) | (Poly { name, params }, Type) (Type, Poly { name, params }) | (Poly { name, params }, Type)
if &name[..] == "Array" || &name[..] == "Set" => if &name[..] == "Array" || &name[..] == "Set" =>
{ {

View file

@ -259,7 +259,8 @@ impl Context {
Signature::Var(_) => None, Signature::Var(_) => None,
}; };
// TODO: set params // TODO: set params
self.grow(__name__, ContextKind::Instant, vis, tv_cache); let kind = ContextKind::from(def.def_kind());
self.grow(__name__, kind, vis, tv_cache);
let obj = self.eval_const_block(&def.body.block).map_err(|e| { let obj = self.eval_const_block(&def.body.block).map_err(|e| {
self.pop(); self.pop();
e e

View file

@ -1827,18 +1827,19 @@ impl Context {
} }
} }
/// FIXME: if trait, returns a freevar // TODO: poly type
pub(crate) fn rec_get_self_t(&self) -> Option<Type> { pub(crate) fn rec_get_self_t(&self) -> Option<Type> {
if self.kind.is_method_def() || self.kind.is_type() { if self.kind.is_method_def() || self.kind.is_type() {
// TODO: poly type // let name = self.name.split(&[':', '.']).last().unwrap();
let name = self.name.split(&[':', '.']).last().unwrap(); /*if let Some((t, _)) = self.rec_get_type(name) {
// let mono_t = mono(self.path(), Str::rc(name)); log!("{t}");
if let Some((t, _)) = self.rec_get_type(name) {
Some(t.clone()) Some(t.clone())
} else { } else {
log!("none");
None None
} }*/
} else if let Some(outer) = self.get_outer().or_else(|| self.get_builtins()) { Some(mono(self.name.clone()))
} else if let Some(outer) = self.get_outer() {
outer.rec_get_self_t() outer.rec_get_self_t()
} else { } else {
None None

View file

@ -412,6 +412,13 @@ impl Context {
let t = self.instantiate_const_expr_as_type(&first.expr, None, tmp_tv_cache)?; let t = self.instantiate_const_expr_as_type(&first.expr, None, tmp_tv_cache)?;
Ok(ref_mut(t, None)) Ok(ref_mut(t, None))
} }
"Self" => self.rec_get_self_t().ok_or_else(|| {
TyCheckErrors::from(TyCheckError::unreachable(
self.cfg.input.clone(),
erg_common::fn_name_full!(),
line!(),
))
}),
other if simple.args.is_empty() => { other if simple.args.is_empty() => {
if let Some(t) = tmp_tv_cache.get_tyvar(other) { if let Some(t) = tmp_tv_cache.get_tyvar(other) {
return Ok(t.clone()); return Ok(t.clone());

View file

@ -582,7 +582,8 @@ impl Context {
} }
} }
ast::Signature::Var(sig) if sig.is_const() => { ast::Signature::Var(sig) if sig.is_const() => {
self.grow(__name__, ContextKind::Instant, sig.vis(), None); let kind = ContextKind::from(def.def_kind());
self.grow(__name__, kind, sig.vis(), None);
let (obj, const_t) = match self.eval_const_block(&def.body.block) { let (obj, const_t) = match self.eval_const_block(&def.body.block) {
Ok(obj) => (obj.clone(), v_enum(set! {obj})), Ok(obj) => (obj.clone(), v_enum(set! {obj})),
Err(e) => { Err(e) => {

View file

@ -1065,6 +1065,10 @@ impl ASTLowerer {
None, None,
), ),
}; };
// assume the class has implemented the trait, regardless of whether the implementation is correct
if let Some((trait_, trait_loc)) = &impl_trait {
self.register_trait_impl(&class, trait_, *trait_loc)?;
}
if let Some(class_root) = self.ctx.get_nominal_type_ctx(&class) { if let Some(class_root) = self.ctx.get_nominal_type_ctx(&class) {
if !class_root.kind.is_class() { if !class_root.kind.is_class() {
return Err(LowerErrors::from(LowerError::method_definition_error( return Err(LowerErrors::from(LowerError::method_definition_error(
@ -1114,7 +1118,6 @@ impl ASTLowerer {
} }
if let Some((trait_, _)) = &impl_trait { if let Some((trait_, _)) = &impl_trait {
self.check_override(&class, Some(trait_)); self.check_override(&class, Some(trait_));
self.register_trait_impl(&class, trait_);
} else { } else {
self.check_override(&class, None); self.check_override(&class, None);
} }
@ -1151,6 +1154,42 @@ impl ASTLowerer {
)) ))
} }
fn register_trait_impl(
&mut self,
class: &Type,
trait_: &Type,
trait_loc: Location,
) -> LowerResult<()> {
// TODO: polymorphic trait
if let Some(impls) = self.ctx.trait_impls.get_mut(&trait_.qual_name()) {
impls.insert(TypeRelationInstance::new(class.clone(), trait_.clone()));
} else {
self.ctx.trait_impls.insert(
trait_.qual_name(),
set! {TypeRelationInstance::new(class.clone(), trait_.clone())},
);
}
let trait_ctx = if let Some(trait_ctx) = self.ctx.get_nominal_type_ctx(trait_) {
trait_ctx.clone()
} else {
// TODO: maybe parameters are wrong
return Err(LowerErrors::from(LowerError::no_var_error(
self.cfg.input.clone(),
line!() as usize,
trait_loc,
self.ctx.caused_by(),
&trait_.local_name(),
None,
)));
};
let (_, class_ctx) = self
.ctx
.get_mut_nominal_type_ctx(class)
.unwrap_or_else(|| todo!("{class} not found"));
class_ctx.register_supertrait(trait_.clone(), &trait_ctx);
Ok(())
}
/// HACK: Cannot be methodized this because `&self` has been taken immediately before. /// HACK: Cannot be methodized this because `&self` has been taken immediately before.
fn check_inheritable( fn check_inheritable(
cfg: &ErgConfig, cfg: &ErgConfig,
@ -1224,35 +1263,22 @@ impl ASTLowerer {
class: &Type, class: &Type,
) -> SingleLowerResult<()> { ) -> SingleLowerResult<()> {
if let Some((impl_trait, loc)) = impl_trait { if let Some((impl_trait, loc)) = impl_trait {
// assume the class has implemented the trait, regardless of whether the implementation is correct
let trait_ctx = if let Some(trait_ctx) = self.ctx.get_nominal_type_ctx(&impl_trait) {
trait_ctx.clone()
} else {
// TODO: maybe parameters are wrong
return Err(LowerError::no_var_error(
self.cfg.input.clone(),
line!() as usize,
loc,
self.ctx.caused_by(),
&impl_trait.local_name(),
None,
));
};
let (_, class_ctx) = self
.ctx
.get_mut_nominal_type_ctx(class)
.unwrap_or_else(|| todo!("{class} not found"));
class_ctx.register_supertrait(impl_trait.clone(), &trait_ctx);
let mut unverified_names = self.ctx.locals.keys().collect::<Set<_>>(); let mut unverified_names = self.ctx.locals.keys().collect::<Set<_>>();
if let Some(trait_obj) = self.ctx.rec_get_const_obj(&impl_trait.local_name()) { if let Some(trait_obj) = self.ctx.rec_get_const_obj(&impl_trait.local_name()) {
if let ValueObj::Type(typ) = trait_obj { if let ValueObj::Type(typ) = trait_obj {
match typ { match typ {
TypeObj::Generated(gen) => match gen.require_or_sup().unwrap().typ() { TypeObj::Generated(gen) => match gen.require_or_sup().unwrap().typ() {
Type::Record(attrs) => { Type::Record(attrs) => {
for (field, field_typ) in attrs.iter() { for (field, decl_t) in attrs.iter() {
if let Some((name, vi)) = self.ctx.get_local_kv(&field.symbol) { if let Some((name, vi)) = self.ctx.get_local_kv(&field.symbol) {
let def_t = &vi.t;
// A(<: Add(R)), R -> A.Output
// => A(<: Int), R -> A.Output
let replaced_decl_t =
decl_t.clone().replace(&impl_trait, class);
unverified_names.remove(name); unverified_names.remove(name);
if !self.ctx.supertype_of(field_typ, &vi.t) { // def_t must be subtype of decl_t
if !self.ctx.supertype_of(&replaced_decl_t, def_t) {
self.errs.push(LowerError::trait_member_type_error( self.errs.push(LowerError::trait_member_type_error(
self.cfg.input.clone(), self.cfg.input.clone(),
line!() as usize, line!() as usize,
@ -1260,7 +1286,7 @@ impl ASTLowerer {
self.ctx.caused_by(), self.ctx.caused_by(),
name.inspect(), name.inspect(),
&impl_trait, &impl_trait,
field_typ, decl_t,
&vi.t, &vi.t,
None, None,
)); ));
@ -1285,8 +1311,11 @@ impl ASTLowerer {
for (decl_name, decl_vi) in ctx.decls.iter() { for (decl_name, decl_vi) in ctx.decls.iter() {
if let Some((name, vi)) = self.ctx.get_local_kv(decl_name.inspect()) if let Some((name, vi)) = self.ctx.get_local_kv(decl_name.inspect())
{ {
let def_t = &vi.t;
let replaced_decl_t =
decl_vi.t.clone().replace(&impl_trait, class);
unverified_names.remove(name); unverified_names.remove(name);
if !self.ctx.supertype_of(&decl_vi.t, &vi.t) { if !self.ctx.supertype_of(&replaced_decl_t, def_t) {
self.errs.push(LowerError::trait_member_type_error( self.errs.push(LowerError::trait_member_type_error(
self.cfg.input.clone(), self.cfg.input.clone(),
line!() as usize, line!() as usize,
@ -1351,19 +1380,6 @@ impl ASTLowerer {
Ok(()) Ok(())
} }
fn register_trait_impl(&mut self, class: &Type, trait_: &Type) {
let trait_impls = &mut self.ctx.outer.as_mut().unwrap().trait_impls;
// TODO: polymorphic trait
if let Some(impls) = trait_impls.get_mut(&trait_.qual_name()) {
impls.insert(TypeRelationInstance::new(class.clone(), trait_.clone()));
} else {
trait_impls.insert(
trait_.qual_name(),
set! {TypeRelationInstance::new(class.clone(), trait_.clone())},
);
}
}
fn check_collision_and_push(&mut self, class: Type) { fn check_collision_and_push(&mut self, class: Type) {
let methods = self.ctx.pop(); let methods = self.ctx.pop();
let (_, class_root) = self let (_, class_root) = self

View file

@ -2167,6 +2167,79 @@ impl Type {
other => other.clone(), other => other.clone(),
} }
} }
pub fn replace(self, target: &Type, to: &Type) -> Type {
if &self == target {
return to.clone();
}
match self {
Self::FreeVar(fv) if fv.is_linked() => fv.crack().clone().replace(target, to),
Self::Refinement(mut refine) => {
refine.t = Box::new(refine.t.replace(target, to));
Self::Refinement(refine)
}
Self::Record(mut rec) => {
for v in rec.values_mut() {
*v = std::mem::take(v).replace(target, to);
}
Self::Record(rec)
}
Self::Subr(mut subr) => {
for nd in subr.non_default_params.iter_mut() {
*nd.typ_mut() = std::mem::take(nd.typ_mut()).replace(target, to);
}
if let Some(var) = subr.var_params.as_mut() {
*var.as_mut().typ_mut() =
std::mem::take(var.as_mut().typ_mut()).replace(target, to);
}
for d in subr.default_params.iter_mut() {
*d.typ_mut() = std::mem::take(d.typ_mut()).replace(target, to);
}
subr.return_t = Box::new(subr.return_t.replace(target, to));
Self::Subr(subr)
}
Self::Callable { param_ts, return_t } => {
let param_ts = param_ts
.into_iter()
.map(|t| t.replace(target, to))
.collect();
let return_t = Box::new(return_t.replace(target, to));
Self::Callable { param_ts, return_t }
}
Self::Quantified(quant) => quant.replace(target, to).quantify(),
Self::Poly { name, params } => {
let params = params
.into_iter()
.map(|tp| match tp {
TyParam::Type(t) => TyParam::t(t.replace(target, to)),
other => other,
})
.collect();
Self::Poly { name, params }
}
Self::Ref(t) => Self::Ref(Box::new(t.replace(target, to))),
Self::RefMut { before, after } => Self::RefMut {
before: Box::new(before.replace(target, to)),
after: after.map(|t| Box::new(t.replace(target, to))),
},
Self::And(l, r) => {
let l = l.replace(target, to);
let r = r.replace(target, to);
Self::And(Box::new(l), Box::new(r))
}
Self::Or(l, r) => {
let l = l.replace(target, to);
let r = r.replace(target, to);
Self::Or(Box::new(l), Box::new(r))
}
Self::Not(l, r) => {
let l = l.replace(target, to);
let r = r.replace(target, to);
Self::Not(Box::new(l), Box::new(r))
}
other => other,
}
}
} }
/// バイトコード命令で、in-place型付けをするオブジェクト /// バイトコード命令で、in-place型付けをするオブジェクト

View file

@ -2832,7 +2832,7 @@ impl Parser {
) -> ParseResult<LambdaSignature> { ) -> ParseResult<LambdaSignature> {
debug_call_info!(self); debug_call_info!(self);
let sig = self let sig = self
.convert_rhs_to_param(*tasc.expr, true) .convert_rhs_to_param(Expr::TypeAsc(tasc), true)
.map_err(|_| self.stack_dec())?; .map_err(|_| self.stack_dec())?;
self.level -= 1; self.level -= 1;
Ok(LambdaSignature::new( Ok(LambdaSignature::new(