fix: sub-unification bugs

This commit is contained in:
Shunsuke Shibayama 2023-03-22 15:38:47 +09:00
parent b318395a32
commit 0079aed860
10 changed files with 190 additions and 153 deletions

View file

@ -23,70 +23,69 @@ use Variance::*;
impl Context {
pub const TOP_LEVEL: usize = 1;
fn generalize_tp(&self, free: TyParam, variance: Variance, uninit: bool) -> TyParam {
fn generalize_tp(
&self,
free: TyParam,
variance: Variance,
qnames: &Set<Str>,
uninit: bool,
) -> TyParam {
match free {
TyParam::Type(t) => TyParam::t(self.generalize_t_inner(*t, variance, uninit)),
TyParam::Type(t) => TyParam::t(self.generalize_t_inner(*t, variance, qnames, uninit)),
TyParam::FreeVar(fv) if fv.is_generalized() => TyParam::FreeVar(fv),
TyParam::FreeVar(fv) if fv.is_linked() => {
self.generalize_tp(fv.crack().clone(), variance, uninit)
/*let fv_mut = unsafe { fv.as_ptr().as_mut().unwrap() };
if let FreeKind::Linked(tp) = fv_mut {
*tp = self.generalize_tp(tp.clone(), variance, uninit);
} else {
assume_unreachable!()
}
TyParam::FreeVar(fv)*/
self.generalize_tp(fv.crack().clone(), variance, qnames, uninit)
}
// TODO: Polymorphic generalization
TyParam::FreeVar(fv) if fv.level() > Some(self.level) => {
let constr = self.generalize_constraint(&fv, variance);
let constr = self.generalize_constraint(&fv, qnames, variance);
fv.update_constraint(constr, true);
fv.generalize();
TyParam::FreeVar(fv)
}
TyParam::Array(tps) => TyParam::Array(
tps.into_iter()
.map(|tp| self.generalize_tp(tp, variance, uninit))
.map(|tp| self.generalize_tp(tp, variance, qnames, uninit))
.collect(),
),
TyParam::Tuple(tps) => TyParam::Tuple(
tps.into_iter()
.map(|tp| self.generalize_tp(tp, variance, uninit))
.map(|tp| self.generalize_tp(tp, variance, qnames, uninit))
.collect(),
),
TyParam::Dict(tps) => TyParam::Dict(
tps.into_iter()
.map(|(k, v)| {
(
self.generalize_tp(k, variance, uninit),
self.generalize_tp(v, variance, uninit),
self.generalize_tp(k, variance, qnames, uninit),
self.generalize_tp(v, variance, qnames, uninit),
)
})
.collect(),
),
TyParam::Record(rec) => TyParam::Record(
rec.into_iter()
.map(|(field, tp)| (field, self.generalize_tp(tp, variance, uninit)))
.map(|(field, tp)| (field, self.generalize_tp(tp, variance, qnames, uninit)))
.collect(),
),
TyParam::Lambda(lambda) => {
let nd_params = lambda
.nd_params
.into_iter()
.map(|pt| pt.map_type(|t| self.generalize_t_inner(t, variance, uninit)))
.map(|pt| pt.map_type(|t| self.generalize_t_inner(t, variance, qnames, uninit)))
.collect::<Vec<_>>();
let var_params = lambda
.var_params
.map(|pt| pt.map_type(|t| self.generalize_t_inner(t, variance, uninit)));
let var_params = lambda.var_params.map(|pt| {
pt.map_type(|t| self.generalize_t_inner(t, variance, qnames, uninit))
});
let d_params = lambda
.d_params
.into_iter()
.map(|pt| pt.map_type(|t| self.generalize_t_inner(t, variance, uninit)))
.map(|pt| pt.map_type(|t| self.generalize_t_inner(t, variance, qnames, uninit)))
.collect::<Vec<_>>();
let body = lambda
.body
.into_iter()
.map(|tp| self.generalize_tp(tp, variance, uninit))
.map(|tp| self.generalize_tp(tp, variance, qnames, uninit))
.collect();
TyParam::Lambda(TyParamLambda::new(
lambda.const_,
@ -98,24 +97,26 @@ impl Context {
}
TyParam::FreeVar(_) => free,
TyParam::Proj { obj, attr } => {
let obj = self.generalize_tp(*obj, variance, uninit);
let obj = self.generalize_tp(*obj, variance, qnames, uninit);
TyParam::proj(obj, attr)
}
TyParam::Erased(t) => TyParam::erased(self.generalize_t_inner(*t, variance, uninit)),
TyParam::Erased(t) => {
TyParam::erased(self.generalize_t_inner(*t, variance, qnames, uninit))
}
TyParam::App { name, args } => {
let args = args
.into_iter()
.map(|tp| self.generalize_tp(tp, variance, uninit))
.map(|tp| self.generalize_tp(tp, variance, qnames, uninit))
.collect();
TyParam::App { name, args }
}
TyParam::BinOp { op, lhs, rhs } => {
let lhs = self.generalize_tp(*lhs, variance, uninit);
let rhs = self.generalize_tp(*rhs, variance, uninit);
let lhs = self.generalize_tp(*lhs, variance, qnames, uninit);
let rhs = self.generalize_tp(*rhs, variance, qnames, uninit);
TyParam::bin(op, lhs, rhs)
}
TyParam::UnaryOp { op, val } => {
let val = self.generalize_tp(*val, variance, uninit);
let val = self.generalize_tp(*val, variance, qnames, uninit);
TyParam::unary(op, val)
}
other if other.has_no_unbound_var() => other,
@ -126,7 +127,7 @@ impl Context {
/// Quantification occurs only once in function types.
/// Therefore, this method is called only once at the top level, and `generalize_t_inner` is called inside.
pub(crate) fn generalize_t(&self, free_type: Type) -> Type {
let maybe_unbound_t = self.generalize_t_inner(free_type, Covariant, false);
let maybe_unbound_t = self.generalize_t_inner(free_type, Covariant, &set! {}, false);
if maybe_unbound_t.is_subr() && maybe_unbound_t.has_qvar() {
maybe_unbound_t.quantify()
} else {
@ -141,56 +142,64 @@ impl Context {
/// generalize_t(?T(<: Add(?T(<: Eq(?T(<: ...)))) -> ?T) == |'T <: Add('T)| 'T -> 'T
/// generalize_t(?T(<: TraitX) -> Int) == TraitX -> Int // 戻り値に現れないなら量化しない
/// ```
fn generalize_t_inner(&self, free_type: Type, variance: Variance, uninit: bool) -> Type {
fn generalize_t_inner(
&self,
free_type: Type,
variance: Variance,
qnames: &Set<Str>,
uninit: bool,
) -> Type {
match free_type {
FreeVar(fv) if fv.is_linked() => {
self.generalize_t_inner(fv.crack().clone(), variance, uninit)
/*let fv_mut = unsafe { fv.as_ptr().as_mut().unwrap() };
if let FreeKind::Linked(t) = fv_mut {
*t = self.generalize_t_inner(t.clone(), variance, uninit);
} else {
assume_unreachable!()
}
Type::FreeVar(fv)*/
self.generalize_t_inner(fv.crack().clone(), variance, qnames, uninit)
}
FreeVar(fv) if fv.is_generalized() => Type::FreeVar(fv),
// TODO: Polymorphic generalization
FreeVar(fv) if fv.level().unwrap() > self.level => {
if uninit {
// use crate::ty::free::GENERIC_LEVEL;
// return named_free_var(fv.unbound_name().unwrap(), GENERIC_LEVEL, Constraint::Uninited);
fv.generalize();
return Type::FreeVar(fv);
}
if let Some((l, r)) = fv.get_subsup() {
if let Some((sub, sup)) = fv.get_subsup() {
// |Int <: T <: Int| T -> T ==> Int -> Int
if l == r {
let t = self.generalize_t_inner(l, variance, uninit);
if sub == sup {
let t = self.generalize_t_inner(sub, variance, qnames, uninit);
fv.forced_link(&t);
FreeVar(fv)
} else if r != Obj && self.is_class(&r) && variance == Contravariant {
} else if sup != Obj
&& !qnames.contains(&fv.unbound_name().unwrap())
&& variance == Contravariant
{
// |T <: Bool| T -> Int ==> Bool -> Int
self.generalize_t_inner(r, variance, uninit)
} else if l != Never && self.is_class(&l) && variance == Covariant {
self.generalize_t_inner(sup, variance, qnames, uninit)
} else if sub != Never
&& !qnames.contains(&fv.unbound_name().unwrap())
&& variance == Covariant
{
// |T :> Int| X -> T ==> X -> Int
self.generalize_t_inner(l, variance, uninit)
self.generalize_t_inner(sub, variance, qnames, uninit)
} else {
fv.update_constraint(self.generalize_constraint(&fv, variance), true);
fv.update_constraint(
self.generalize_constraint(&fv, qnames, variance),
true,
);
fv.generalize();
Type::FreeVar(fv)
}
} else {
// ?S(: Str) => 'S
fv.update_constraint(self.generalize_constraint(&fv, variance), true);
fv.update_constraint(self.generalize_constraint(&fv, qnames, variance), true);
fv.generalize();
Type::FreeVar(fv)
}
}
Subr(mut subr) => {
let qnames = subr.essential_qnames();
subr.non_default_params.iter_mut().for_each(|nd_param| {
*nd_param.typ_mut() = self.generalize_t_inner(
mem::take(nd_param.typ_mut()),
Contravariant,
&qnames,
uninit,
);
});
@ -198,6 +207,7 @@ impl Context {
*var_args.typ_mut() = self.generalize_t_inner(
mem::take(var_args.typ_mut()),
Contravariant,
&qnames,
uninit,
);
}
@ -205,10 +215,11 @@ impl Context {
*d_param.typ_mut() = self.generalize_t_inner(
mem::take(d_param.typ_mut()),
Contravariant,
&qnames,
uninit,
);
});
let return_t = self.generalize_t_inner(*subr.return_t, Covariant, uninit);
let return_t = self.generalize_t_inner(*subr.return_t, Covariant, &qnames, uninit);
subr_t(
subr.kind,
subr.non_default_params,
@ -220,30 +231,34 @@ impl Context {
Record(rec) => {
let fields = rec
.into_iter()
.map(|(name, t)| (name, self.generalize_t_inner(t, variance, uninit)))
.map(|(name, t)| (name, self.generalize_t_inner(t, variance, qnames, uninit)))
.collect();
Type::Record(fields)
}
Callable { .. } => todo!(),
Ref(t) => ref_(self.generalize_t_inner(*t, variance, uninit)),
Ref(t) => ref_(self.generalize_t_inner(*t, variance, qnames, uninit)),
RefMut { before, after } => {
let after = after.map(|aft| self.generalize_t_inner(*aft, variance, uninit));
ref_mut(self.generalize_t_inner(*before, variance, uninit), after)
let after =
after.map(|aft| self.generalize_t_inner(*aft, variance, qnames, uninit));
ref_mut(
self.generalize_t_inner(*before, variance, qnames, uninit),
after,
)
}
Refinement(refine) => {
let t = self.generalize_t_inner(*refine.t, variance, uninit);
let pred = self.generalize_pred(*refine.pred, variance, uninit);
let t = self.generalize_t_inner(*refine.t, variance, qnames, uninit);
let pred = self.generalize_pred(*refine.pred, variance, qnames, uninit);
refinement(refine.var, t, pred)
}
Poly { name, mut params } => {
let params = params
.iter_mut()
.map(|p| self.generalize_tp(mem::take(p), variance, uninit))
.map(|p| self.generalize_tp(mem::take(p), variance, qnames, uninit))
.collect::<Vec<_>>();
poly(name, params)
}
Proj { lhs, rhs } => {
let lhs = self.generalize_t_inner(*lhs, variance, uninit);
let lhs = self.generalize_t_inner(*lhs, variance, qnames, uninit);
proj(lhs, rhs)
}
ProjCall {
@ -251,83 +266,94 @@ impl Context {
attr_name,
mut args,
} => {
let lhs = self.generalize_tp(*lhs, variance, uninit);
let lhs = self.generalize_tp(*lhs, variance, qnames, uninit);
for arg in args.iter_mut() {
*arg = self.generalize_tp(mem::take(arg), variance, uninit);
*arg = self.generalize_tp(mem::take(arg), variance, qnames, uninit);
}
proj_call(lhs, attr_name, args)
}
And(l, r) => {
let l = self.generalize_t_inner(*l, variance, uninit);
let r = self.generalize_t_inner(*r, variance, uninit);
let l = self.generalize_t_inner(*l, variance, qnames, uninit);
let r = self.generalize_t_inner(*r, variance, qnames, uninit);
// not `self.intersection` because types are generalized
and(l, r)
}
Or(l, r) => {
let l = self.generalize_t_inner(*l, variance, uninit);
let r = self.generalize_t_inner(*r, variance, uninit);
let l = self.generalize_t_inner(*l, variance, qnames, uninit);
let r = self.generalize_t_inner(*r, variance, qnames, uninit);
// not `self.union` because types are generalized
or(l, r)
}
Not(l) => not(self.generalize_t_inner(*l, variance, uninit)),
Not(l) => not(self.generalize_t_inner(*l, variance, qnames, uninit)),
Structural(t) => self
.generalize_t_inner(*t, variance, uninit)
.generalize_t_inner(*t, variance, qnames, uninit)
.structuralize(),
// REVIEW: その他何でもそのまま通していいのか?
other => other,
}
}
fn generalize_constraint<T: CanbeFree>(&self, fv: &Free<T>, variance: Variance) -> Constraint {
fn generalize_constraint<T: CanbeFree>(
&self,
fv: &Free<T>,
qnames: &Set<Str>,
variance: Variance,
) -> Constraint {
if let Some((sub, sup)) = fv.get_subsup() {
let sub = self.generalize_t_inner(sub, variance, true);
let sup = self.generalize_t_inner(sup, variance, true);
let sub = self.generalize_t_inner(sub, variance, qnames, true);
let sup = self.generalize_t_inner(sup, variance, qnames, true);
Constraint::new_sandwiched(sub, sup)
} else if let Some(ty) = fv.get_type() {
let t = self.generalize_t_inner(ty, variance, true);
let t = self.generalize_t_inner(ty, variance, qnames, true);
Constraint::new_type_of(t)
} else {
unreachable!()
}
}
fn generalize_pred(&self, pred: Predicate, variance: Variance, uninit: bool) -> Predicate {
fn generalize_pred(
&self,
pred: Predicate,
variance: Variance,
qnames: &Set<Str>,
uninit: bool,
) -> Predicate {
match pred {
Predicate::Const(_) => pred,
Predicate::Value(ValueObj::Type(mut typ)) => {
*typ.typ_mut() =
self.generalize_t_inner(mem::take(typ.typ_mut()), variance, uninit);
self.generalize_t_inner(mem::take(typ.typ_mut()), variance, qnames, uninit);
Predicate::Value(ValueObj::Type(typ))
}
Predicate::Value(_) => pred,
Predicate::Equal { lhs, rhs } => {
let rhs = self.generalize_tp(rhs, variance, uninit);
let rhs = self.generalize_tp(rhs, variance, qnames, uninit);
Predicate::eq(lhs, rhs)
}
Predicate::GreaterEqual { lhs, rhs } => {
let rhs = self.generalize_tp(rhs, variance, uninit);
let rhs = self.generalize_tp(rhs, variance, qnames, uninit);
Predicate::ge(lhs, rhs)
}
Predicate::LessEqual { lhs, rhs } => {
let rhs = self.generalize_tp(rhs, variance, uninit);
let rhs = self.generalize_tp(rhs, variance, qnames, uninit);
Predicate::le(lhs, rhs)
}
Predicate::NotEqual { lhs, rhs } => {
let rhs = self.generalize_tp(rhs, variance, uninit);
let rhs = self.generalize_tp(rhs, variance, qnames, uninit);
Predicate::ne(lhs, rhs)
}
Predicate::And(lhs, rhs) => {
let lhs = self.generalize_pred(*lhs, variance, uninit);
let rhs = self.generalize_pred(*rhs, variance, uninit);
let lhs = self.generalize_pred(*lhs, variance, qnames, uninit);
let rhs = self.generalize_pred(*rhs, variance, qnames, uninit);
Predicate::and(lhs, rhs)
}
Predicate::Or(lhs, rhs) => {
let lhs = self.generalize_pred(*lhs, variance, uninit);
let rhs = self.generalize_pred(*rhs, variance, uninit);
let lhs = self.generalize_pred(*lhs, variance, qnames, uninit);
let rhs = self.generalize_pred(*rhs, variance, qnames, uninit);
Predicate::or(lhs, rhs)
}
Predicate::Not(pred) => {
let pred = self.generalize_pred(*pred, variance, uninit);
let pred = self.generalize_pred(*pred, variance, qnames, uninit);
!pred
}
}