fix: type quantification bugs

This commit is contained in:
Shunsuke Shibayama 2023-02-20 20:38:36 +09:00
parent 6d18fde0b1
commit 6a607870f3
8 changed files with 328 additions and 164 deletions

View file

@ -1,6 +1,8 @@
use std::mem;
use erg_common::set::Set;
use erg_common::traits::{Locational, Stream};
use erg_common::Str;
use erg_common::{assume_unreachable, dict, fn_name, set};
#[allow(unused_imports)]
use erg_common::{fmt_vec, log};
@ -286,26 +288,27 @@ impl Context {
&self,
tp: TyParam,
variance: Variance,
qnames: &Set<Str>,
loc: &impl Locational,
) -> TyCheckResult<TyParam> {
match tp {
TyParam::FreeVar(fv) if fv.is_linked() => {
let inner = fv.unwrap_linked();
self.deref_tp(inner, variance, loc)
self.deref_tp(inner, variance, &set! {}, loc)
}
TyParam::FreeVar(_fv) if self.level == 0 => Err(TyCheckErrors::from(
TyCheckError::dummy_infer_error(self.cfg.input.clone(), fn_name!(), line!()),
)),
TyParam::Type(t) => Ok(TyParam::t(self.deref_tyvar(*t, variance, loc)?)),
TyParam::Type(t) => Ok(TyParam::t(self.deref_tyvar(*t, variance, qnames, loc)?)),
TyParam::App { name, mut args } => {
for param in args.iter_mut() {
*param = self.deref_tp(mem::take(param), variance, loc)?;
*param = self.deref_tp(mem::take(param), variance, qnames, loc)?;
}
Ok(TyParam::App { name, args })
}
TyParam::BinOp { op, lhs, rhs } => {
let lhs = self.deref_tp(*lhs, variance, loc)?;
let rhs = self.deref_tp(*rhs, variance, loc)?;
let lhs = self.deref_tp(*lhs, variance, qnames, loc)?;
let rhs = self.deref_tp(*rhs, variance, qnames, loc)?;
Ok(TyParam::BinOp {
op,
lhs: Box::new(lhs),
@ -313,7 +316,7 @@ impl Context {
})
}
TyParam::UnaryOp { op, val } => {
let val = self.deref_tp(*val, variance, loc)?;
let val = self.deref_tp(*val, variance, qnames, loc)?;
Ok(TyParam::UnaryOp {
op,
val: Box::new(val),
@ -322,14 +325,14 @@ impl Context {
TyParam::Array(tps) => {
let mut new_tps = vec![];
for tp in tps {
new_tps.push(self.deref_tp(tp, variance, loc)?);
new_tps.push(self.deref_tp(tp, variance, qnames, loc)?);
}
Ok(TyParam::Array(new_tps))
}
TyParam::Tuple(tps) => {
let mut new_tps = vec![];
for tp in tps {
new_tps.push(self.deref_tp(tp, variance, loc)?);
new_tps.push(self.deref_tp(tp, variance, qnames, loc)?);
}
Ok(TyParam::Tuple(new_tps))
}
@ -337,8 +340,8 @@ impl Context {
let mut new_dic = dict! {};
for (k, v) in dic.into_iter() {
new_dic.insert(
self.deref_tp(k, variance, loc)?,
self.deref_tp(v, variance, loc)?,
self.deref_tp(k, variance, qnames, loc)?,
self.deref_tp(v, variance, qnames, loc)?,
);
}
Ok(TyParam::Dict(new_dic))
@ -346,7 +349,7 @@ impl Context {
TyParam::Set(set) => {
let mut new_set = set! {};
for v in set.into_iter() {
new_set.insert(self.deref_tp(v, variance, loc)?);
new_set.insert(self.deref_tp(v, variance, qnames, loc)?);
}
Ok(TyParam::Set(new_set))
}
@ -361,16 +364,17 @@ impl Context {
&self,
constraint: Constraint,
variance: Variance,
qnames: &Set<Str>,
loc: &impl Locational,
) -> TyCheckResult<Constraint> {
match constraint {
Constraint::Sandwiched { sub, sup } => Ok(Constraint::new_sandwiched(
self.deref_tyvar(sub, variance, loc)?,
self.deref_tyvar(sup, variance, loc)?,
self.deref_tyvar(sub, variance, qnames, loc)?,
self.deref_tyvar(sup, variance, qnames, loc)?,
)),
Constraint::TypeOf(t) => Ok(Constraint::new_type_of(
self.deref_tyvar(t, variance, qnames, loc)?,
)),
Constraint::TypeOf(t) => {
Ok(Constraint::new_type_of(self.deref_tyvar(t, variance, loc)?))
}
_ => unreachable!(),
}
}
@ -380,6 +384,7 @@ impl Context {
sub_t: Type,
super_t: Type,
variance: Variance,
qnames: &Set<Str>,
loc: &impl Locational,
) -> TyCheckResult<Type> {
// TODO: Subr, ...
@ -418,7 +423,7 @@ impl Context {
}
Ok(poly(rn, tps))
}
(sub_t, super_t) => self.validate_simple_subsup(sub_t, super_t, variance, loc),
(sub_t, super_t) => self.validate_simple_subsup(sub_t, super_t, variance, qnames, loc),
}
}
@ -427,22 +432,23 @@ impl Context {
sub_t: Type,
super_t: Type,
variance: Variance,
qnames: &Set<Str>,
loc: &impl Locational,
) -> TyCheckResult<Type> {
if self.is_trait(&super_t) {
self.check_trait_impl(&sub_t, &super_t, loc)?;
self.check_trait_impl(&sub_t, &super_t, &set! {}, loc)?;
}
// REVIEW: Even if type constraints can be satisfied, implementation may not exist
if self.subtype_of(&sub_t, &super_t) {
let sub_t = if cfg!(feature = "debug") {
sub_t
} else {
self.deref_tyvar(sub_t, variance, loc)?
self.deref_tyvar(sub_t, variance, qnames, loc)?
};
let super_t = if cfg!(feature = "debug") {
super_t
} else {
self.deref_tyvar(super_t, variance, loc)?
self.deref_tyvar(super_t, variance, qnames, loc)?
};
match variance {
Variance::Covariant => Ok(sub_t),
@ -467,12 +473,12 @@ impl Context {
let sub_t = if cfg!(feature = "debug") {
sub_t
} else {
self.deref_tyvar(sub_t, variance, loc)?
self.deref_tyvar(sub_t, variance, qnames, loc)?
};
let super_t = if cfg!(feature = "debug") {
super_t
} else {
self.deref_tyvar(super_t, variance, loc)?
self.deref_tyvar(super_t, variance, qnames, loc)?
};
Err(TyCheckErrors::from(TyCheckError::subtyping_error(
self.cfg.input.clone(),
@ -497,9 +503,19 @@ impl Context {
&self,
t: Type,
variance: Variance,
qnames: &Set<Str>,
loc: &impl Locational,
) -> TyCheckResult<Type> {
match t {
Type::FreeVar(fv) if fv.is_linked() => {
let t = fv.unwrap_linked();
self.deref_tyvar(t, variance, qnames, loc)
}
Type::FreeVar(fv)
if fv.is_generalized() && qnames.contains(&fv.unbound_name().unwrap()) =>
{
Ok(Type::FreeVar(fv))
}
// ?T(:> Nat, <: Int)[n] ==> Nat (self.level <= n)
// ?T(:> Nat, <: Sub ?U(:> {1}))[n] ==> Nat
// ?T(<: Int, :> Add(?T)) ==> Int
@ -511,7 +527,7 @@ impl Context {
// if fv == ?T(<: Int, :> Add(?T)), deref_tyvar(super_t) will cause infinite loop
// so we need to force linking
fv.forced_undoable_link(&sub_t);
let res = self.validate_subsup(sub_t, super_t, variance, loc);
let res = self.validate_subsup(sub_t, super_t, variance, qnames, loc);
fv.undo();
res
} else {
@ -536,15 +552,12 @@ impl Context {
Ok(Type::FreeVar(fv))
} else {
let new_constraint = fv.crack_constraint().clone();
let new_constraint = self.deref_constraint(new_constraint, variance, loc)?;
let new_constraint =
self.deref_constraint(new_constraint, variance, qnames, loc)?;
fv.update_constraint(new_constraint, true);
Ok(Type::FreeVar(fv))
}
}
Type::FreeVar(fv) if fv.is_linked() => {
let t = fv.unwrap_linked();
self.deref_tyvar(t, variance, loc)
}
Type::Poly { name, mut params } => {
let typ = poly(&name, params.clone());
let (_, ctx) = self.get_nominal_type_ctx(&typ).ok_or_else(|| {
@ -558,7 +571,7 @@ impl Context {
})?;
let variances = ctx.type_params_variance();
for (param, variance) in params.iter_mut().zip(variances.into_iter()) {
*param = self.deref_tp(mem::take(param), variance, loc)?;
*param = self.deref_tp(mem::take(param), variance, qnames, loc)?;
}
Ok(Type::Poly { name, params })
}
@ -567,6 +580,7 @@ impl Context {
*param.typ_mut() = self.deref_tyvar(
mem::take(param.typ_mut()),
variance * Contravariant,
qnames,
loc,
)?;
}
@ -574,6 +588,7 @@ impl Context {
*var_args.typ_mut() = self.deref_tyvar(
mem::take(var_args.typ_mut()),
variance * Contravariant,
qnames,
loc,
)?;
}
@ -581,30 +596,27 @@ impl Context {
*d_param.typ_mut() = self.deref_tyvar(
mem::take(d_param.typ_mut()),
variance * Contravariant,
qnames,
loc,
)?;
}
subr.return_t = Box::new(self.deref_tyvar(
mem::take(&mut subr.return_t),
variance * Covariant,
qnames,
loc,
)?);
Ok(Type::Subr(subr))
}
Type::Quantified(subr)
if subr.return_t().map(|ret| !ret.has_qvar()).unwrap_or(false) =>
{
let subr = self.deref_tyvar(*subr, variance, loc)?;
Ok(subr)
}
Type::Quantified(subr) => self.eliminate_needless_quant(*subr, variance, loc),
Type::Ref(t) => {
let t = self.deref_tyvar(*t, variance, loc)?;
let t = self.deref_tyvar(*t, variance, qnames, loc)?;
Ok(ref_(t))
}
Type::RefMut { before, after } => {
let before = self.deref_tyvar(*before, variance, loc)?;
let before = self.deref_tyvar(*before, variance, qnames, loc)?;
let after = if let Some(after) = after {
Some(self.deref_tyvar(*after, variance, loc)?)
Some(self.deref_tyvar(*after, variance, qnames, loc)?)
} else {
None
};
@ -613,31 +625,31 @@ impl Context {
// Type::Callable { .. } => todo!(),
Type::Record(mut rec) => {
for (_, field) in rec.iter_mut() {
*field = self.deref_tyvar(mem::take(field), variance, loc)?;
*field = self.deref_tyvar(mem::take(field), variance, qnames, loc)?;
}
Ok(Type::Record(rec))
}
Type::Refinement(refine) => {
let t = self.deref_tyvar(*refine.t, variance, loc)?;
let t = self.deref_tyvar(*refine.t, variance, qnames, loc)?;
// TODO: deref_predicate
Ok(refinement(refine.var, t, refine.preds))
}
Type::And(l, r) => {
let l = self.deref_tyvar(*l, variance, loc)?;
let r = self.deref_tyvar(*r, variance, loc)?;
let l = self.deref_tyvar(*l, variance, qnames, loc)?;
let r = self.deref_tyvar(*r, variance, qnames, loc)?;
Ok(self.intersection(&l, &r))
}
Type::Or(l, r) => {
let l = self.deref_tyvar(*l, variance, loc)?;
let r = self.deref_tyvar(*r, variance, loc)?;
let l = self.deref_tyvar(*l, variance, qnames, loc)?;
let r = self.deref_tyvar(*r, variance, qnames, loc)?;
Ok(self.union(&l, &r))
}
Type::Not(ty) => {
let ty = self.deref_tyvar(*ty, variance, loc)?;
let ty = self.deref_tyvar(*ty, variance, qnames, loc)?;
Ok(self.complement(&ty))
}
Type::Proj { lhs, rhs } => {
let lhs = self.deref_tyvar(*lhs, variance, loc)?;
let lhs = self.deref_tyvar(*lhs, variance, qnames, loc)?;
self.eval_proj(lhs, rhs, self.level, loc)
}
Type::ProjCall {
@ -645,10 +657,10 @@ impl Context {
attr_name,
args,
} => {
let lhs = self.deref_tp(*lhs, variance, loc)?;
let lhs = self.deref_tp(*lhs, variance, qnames, loc)?;
let mut new_args = vec![];
for arg in args.into_iter() {
new_args.push(self.deref_tp(arg, variance, loc)?);
new_args.push(self.deref_tp(arg, variance, qnames, loc)?);
}
self.eval_proj_call(lhs, attr_name, new_args, self.level, loc)
}
@ -656,13 +668,68 @@ impl Context {
}
}
// here ?T can be eliminated
// ?T -> Int
// ?T, ?U -> K(?U)
// Int -> ?T
// here ?T cannot be eliminated
// ?T -> ?T
// ?T -> K(?T)
// ?T -> ?U(:> ?T)
fn eliminate_needless_quant(
&self,
subr: Type,
variance: Variance,
loc: &impl Locational,
) -> TyCheckResult<Type> {
let Type::Subr(mut subr) = subr else { unreachable!() };
let essential_qnames = subr.essential_qnames();
for param in subr.non_default_params.iter_mut() {
*param.typ_mut() = self.deref_tyvar(
mem::take(param.typ_mut()),
variance * Contravariant,
&essential_qnames,
loc,
)?;
}
if let Some(var_args) = &mut subr.var_params {
*var_args.typ_mut() = self.deref_tyvar(
mem::take(var_args.typ_mut()),
variance * Contravariant,
&essential_qnames,
loc,
)?;
}
for d_param in subr.default_params.iter_mut() {
*d_param.typ_mut() = self.deref_tyvar(
mem::take(d_param.typ_mut()),
variance * Contravariant,
&essential_qnames,
loc,
)?;
}
subr.return_t = Box::new(self.deref_tyvar(
mem::take(&mut subr.return_t),
variance * Covariant,
&essential_qnames,
loc,
)?);
let subr = Type::Subr(subr);
if subr.has_qvar() {
Ok(subr.quantify())
} else {
Ok(subr)
}
}
pub fn readable_type(&self, t: Type, is_parameter: bool) -> Type {
let variance = if is_parameter {
Contravariant
} else {
Covariant
};
self.deref_tyvar(t.clone(), variance, &()).unwrap_or(t)
self.deref_tyvar(t.clone(), variance, &set! {}, &())
.unwrap_or(t)
}
pub(crate) fn trait_impl_exists(&self, class: &Type, trait_: &Type) -> bool {
@ -703,18 +770,19 @@ impl Context {
&self,
class: &Type,
trait_: &Type,
qnames: &Set<Str>,
loc: &impl Locational,
) -> TyCheckResult<()> {
if !self.trait_impl_exists(class, trait_) {
let class = if cfg!(feature = "debug") {
class.clone()
} else {
self.deref_tyvar(class.clone(), Variance::Covariant, loc)?
self.deref_tyvar(class.clone(), Variance::Covariant, qnames, loc)?
};
let trait_ = if cfg!(feature = "debug") {
trait_.clone()
} else {
self.deref_tyvar(trait_.clone(), Variance::Covariant, loc)?
self.deref_tyvar(trait_.clone(), Variance::Covariant, qnames, loc)?
};
Err(TyCheckErrors::from(TyCheckError::no_trait_impl_error(
self.cfg.input.clone(),
@ -756,12 +824,12 @@ impl Context {
let mut params = mem::take(&mut self.params);
let mut methods_list = mem::take(&mut self.methods_list);
for (name, vi) in locals.iter_mut() {
if let Ok(t) = self.deref_tyvar(mem::take(&mut vi.t), Covariant, name) {
if let Ok(t) = self.deref_tyvar(mem::take(&mut vi.t), Covariant, &set! {}, name) {
vi.t = t;
}
}
for (name, vi) in params.iter_mut() {
if let Ok(t) = self.deref_tyvar(mem::take(&mut vi.t), Covariant, name) {
if let Ok(t) = self.deref_tyvar(mem::take(&mut vi.t), Covariant, &set! {}, name) {
vi.t = t;
}
}
@ -773,26 +841,22 @@ impl Context {
self.methods_list = methods_list;
}
fn resolve_params_t(&self, params: &mut hir::Params) -> TyCheckResult<()> {
fn resolve_params_t(&self, params: &mut hir::Params, qnames: &Set<Str>) -> TyCheckResult<()> {
for param in params.non_defaults.iter_mut() {
if !param.vi.t.is_qvar() {
param.vi.t = self.deref_tyvar(mem::take(&mut param.vi.t), Contravariant, param)?;
}
param.vi.t =
self.deref_tyvar(mem::take(&mut param.vi.t), Contravariant, qnames, param)?;
}
if let Some(var_params) = &mut params.var_params {
if !var_params.vi.t.is_qvar() {
var_params.vi.t = self.deref_tyvar(
mem::take(&mut var_params.vi.t),
Contravariant,
var_params.as_ref(),
)?;
}
var_params.vi.t = self.deref_tyvar(
mem::take(&mut var_params.vi.t),
Contravariant,
&set! {},
var_params.as_ref(),
)?;
}
for param in params.defaults.iter_mut() {
if !param.sig.vi.t.is_qvar() {
param.sig.vi.t =
self.deref_tyvar(mem::take(&mut param.sig.vi.t), Contravariant, param)?;
}
param.sig.vi.t =
self.deref_tyvar(mem::take(&mut param.sig.vi.t), Contravariant, qnames, param)?;
self.resolve_expr_t(&mut param.default_val)?;
}
Ok(())
@ -809,7 +873,7 @@ impl Context {
Covariant
};
let t = mem::take(acc.ref_mut_t());
*acc.ref_mut_t() = self.deref_tyvar(t, variance, acc)?;
*acc.ref_mut_t() = self.deref_tyvar(t, variance, &set! {}, acc)?;
}
if let hir::Accessor::Attr(attr) = acc {
self.resolve_expr_t(&mut attr.obj)?;
@ -818,14 +882,14 @@ impl Context {
}
hir::Expr::Array(array) => match array {
hir::Array::Normal(arr) => {
arr.t = self.deref_tyvar(mem::take(&mut arr.t), Covariant, arr)?;
arr.t = self.deref_tyvar(mem::take(&mut arr.t), Covariant, &set! {}, arr)?;
for elem in arr.elems.pos_args.iter_mut() {
self.resolve_expr_t(&mut elem.expr)?;
}
Ok(())
}
hir::Array::WithLength(arr) => {
arr.t = self.deref_tyvar(mem::take(&mut arr.t), Covariant, arr)?;
arr.t = self.deref_tyvar(mem::take(&mut arr.t), Covariant, &set! {}, arr)?;
self.resolve_expr_t(&mut arr.elem)?;
self.resolve_expr_t(&mut arr.len)?;
Ok(())
@ -840,7 +904,7 @@ impl Context {
},
hir::Expr::Tuple(tuple) => match tuple {
hir::Tuple::Normal(tup) => {
tup.t = self.deref_tyvar(mem::take(&mut tup.t), Covariant, tup)?;
tup.t = self.deref_tyvar(mem::take(&mut tup.t), Covariant, &set! {}, tup)?;
for elem in tup.elems.pos_args.iter_mut() {
self.resolve_expr_t(&mut elem.expr)?;
}
@ -849,14 +913,14 @@ impl Context {
},
hir::Expr::Set(set) => match set {
hir::Set::Normal(st) => {
st.t = self.deref_tyvar(mem::take(&mut st.t), Covariant, st)?;
st.t = self.deref_tyvar(mem::take(&mut st.t), Covariant, &set! {}, st)?;
for elem in st.elems.pos_args.iter_mut() {
self.resolve_expr_t(&mut elem.expr)?;
}
Ok(())
}
hir::Set::WithLength(st) => {
st.t = self.deref_tyvar(mem::take(&mut st.t), Covariant, st)?;
st.t = self.deref_tyvar(mem::take(&mut st.t), Covariant, &set! {}, st)?;
self.resolve_expr_t(&mut st.elem)?;
self.resolve_expr_t(&mut st.len)?;
Ok(())
@ -864,7 +928,7 @@ impl Context {
},
hir::Expr::Dict(dict) => match dict {
hir::Dict::Normal(dic) => {
dic.t = self.deref_tyvar(mem::take(&mut dic.t), Covariant, dic)?;
dic.t = self.deref_tyvar(mem::take(&mut dic.t), Covariant, &set! {}, dic)?;
for kv in dic.kvs.iter_mut() {
self.resolve_expr_t(&mut kv.key)?;
self.resolve_expr_t(&mut kv.value)?;
@ -880,16 +944,25 @@ impl Context {
),
},
hir::Expr::Record(record) => {
record.t = self.deref_tyvar(mem::take(&mut record.t), Covariant, record)?;
record.t =
self.deref_tyvar(mem::take(&mut record.t), Covariant, &set! {}, record)?;
for attr in record.attrs.iter_mut() {
match &mut attr.sig {
hir::Signature::Var(var) => {
*var.ref_mut_t() =
self.deref_tyvar(mem::take(var.ref_mut_t()), Covariant, var)?;
*var.ref_mut_t() = self.deref_tyvar(
mem::take(var.ref_mut_t()),
Covariant,
&set! {},
var,
)?;
}
hir::Signature::Subr(subr) => {
*subr.ref_mut_t() =
self.deref_tyvar(mem::take(subr.ref_mut_t()), Covariant, subr)?;
*subr.ref_mut_t() = self.deref_tyvar(
mem::take(subr.ref_mut_t()),
Covariant,
&set! {},
subr,
)?;
}
}
for chunk in attr.body.block.iter_mut() {
@ -900,21 +973,24 @@ impl Context {
}
hir::Expr::BinOp(binop) => {
let t = mem::take(binop.signature_mut_t().unwrap());
*binop.signature_mut_t().unwrap() = self.deref_tyvar(t, Covariant, binop)?;
*binop.signature_mut_t().unwrap() =
self.deref_tyvar(t, Covariant, &set! {}, binop)?;
self.resolve_expr_t(&mut binop.lhs)?;
self.resolve_expr_t(&mut binop.rhs)?;
Ok(())
}
hir::Expr::UnaryOp(unaryop) => {
let t = mem::take(unaryop.signature_mut_t().unwrap());
*unaryop.signature_mut_t().unwrap() = self.deref_tyvar(t, Covariant, unaryop)?;
*unaryop.signature_mut_t().unwrap() =
self.deref_tyvar(t, Covariant, &set! {}, unaryop)?;
self.resolve_expr_t(&mut unaryop.expr)?;
Ok(())
}
hir::Expr::Call(call) => {
if let Some(t) = call.signature_mut_t() {
let t = mem::take(t);
*call.signature_mut_t().unwrap() = self.deref_tyvar(t, Covariant, call)?;
*call.signature_mut_t().unwrap() =
self.deref_tyvar(t, Covariant, &set! {}, call)?;
}
self.resolve_expr_t(&mut call.obj)?;
for arg in call.args.pos_args.iter_mut() {
@ -929,10 +1005,20 @@ impl Context {
Ok(())
}
hir::Expr::Def(def) => {
*def.sig.ref_mut_t() =
self.deref_tyvar(mem::take(def.sig.ref_mut_t()), Covariant, &def.sig)?;
*def.sig.ref_mut_t() = self.deref_tyvar(
mem::take(def.sig.ref_mut_t()),
Covariant,
&set! {},
&def.sig,
)?;
let qnames = if let Type::Quantified(quant) = def.sig.ref_t() {
let Type::Subr(subr) = quant.as_ref() else { unreachable!() };
subr.essential_qnames()
} else {
set! {}
};
if let Some(params) = def.sig.params_mut() {
self.resolve_params_t(params)?;
self.resolve_params_t(params, &qnames)?;
}
for chunk in def.body.block.iter_mut() {
self.resolve_expr_t(chunk)?;
@ -940,10 +1026,15 @@ impl Context {
Ok(())
}
hir::Expr::Lambda(lambda) => {
log!(err "{}", lambda.t);
lambda.t = self.deref_tyvar(mem::take(&mut lambda.t), Covariant, lambda)?;
log!(err "{}", lambda.t);
self.resolve_params_t(&mut lambda.params)?;
lambda.t =
self.deref_tyvar(mem::take(&mut lambda.t), Covariant, &set! {}, lambda)?;
let qnames = if let Type::Quantified(quant) = lambda.ref_t() {
let Type::Subr(subr) = quant.as_ref() else { unreachable!() };
subr.essential_qnames()
} else {
set! {}
};
self.resolve_params_t(&mut lambda.params, &qnames)?;
for chunk in lambda.body.iter_mut() {
self.resolve_expr_t(chunk)?;
}