fix: Context::union/intersection

This commit is contained in:
Shunsuke Shibayama 2023-04-10 11:51:45 +09:00
parent 3ed863cef6
commit 7c8b8a66a1
4 changed files with 80 additions and 46 deletions

View file

@ -48,7 +48,8 @@ impl Context {
pub(crate) fn eq_tp(&self, lhs: &TyParam, rhs: &TyParam) -> bool {
match (lhs, rhs) {
(TyParam::Type(lhs), TyParam::Type(rhs)) => return self.same_type_of(lhs, rhs),
(TyParam::Type(lhs), TyParam::Type(rhs))
| (TyParam::Erased(lhs), TyParam::Erased(rhs)) => return self.same_type_of(lhs, rhs),
(TyParam::Mono(l), TyParam::Mono(r)) => {
if let (Some(l), Some(r)) = (self.rec_get_const_obj(l), self.rec_get_const_obj(r)) {
return l == r;
@ -502,7 +503,7 @@ impl Context {
(Type, Poly { name, params }) | (Poly { name, params }, Type)
if &name[..] == "Array" || &name[..] == "Set" =>
{
let elem_t = self.convert_tp_into_ty(params[0].clone()).unwrap();
let elem_t = self.convert_tp_into_type(params[0].clone()).unwrap();
self.supertype_of(&Type, &elem_t)
}
(Type, Poly { name, params }) | (Poly { name, params }, Type)
@ -515,7 +516,7 @@ impl Context {
return self.supertype_of(&Type, &arr_t);
} else if let Ok(tps) = Vec::try_from(params[0].clone()) {
for tp in tps {
let Ok(t) = self.convert_tp_into_ty(tp) else {
let Ok(t) = self.convert_tp_into_type(tp) else {
return false;
};
if !self.supertype_of(&Type, &t) {
@ -539,10 +540,10 @@ impl Context {
return false;
};
for (k, v) in dict.into_iter() {
let Ok(k) = self.convert_tp_into_ty(k) else {
let Ok(k) = self.convert_tp_into_type(k) else {
return false;
};
let Ok(v) = self.convert_tp_into_ty(v) else {
let Ok(v) = self.convert_tp_into_type(v) else {
return false;
};
if !self.supertype_of(&Type, &k) || !self.supertype_of(&Type, &v) {
@ -682,8 +683,8 @@ impl Context {
}
// [Int; 2] :> [Int; 3]
if &ln[..] == "Array" || &ln[..] == "Set" {
let lt = self.convert_tp_into_ty(lparams[0].clone()).unwrap();
let rt = self.convert_tp_into_ty(rparams[0].clone()).unwrap();
let lt = self.convert_tp_into_type(lparams[0].clone()).unwrap();
let rt = self.convert_tp_into_type(rparams[0].clone()).unwrap();
let llen = lparams[1].clone();
let rlen = rparams[1].clone();
self.supertype_of(&lt, &rt)
@ -975,16 +976,6 @@ impl Context {
if lhs == rhs {
return lhs.clone();
}
// `?T or ?U` will not be unified
// `Set!(?T, 3) or Set(?T, 3)` wii be unified to Set(?T, 3)
if !lhs.is_unbound_var() && !rhs.is_unbound_var() {
match (self.supertype_of(lhs, rhs), self.subtype_of(lhs, rhs)) {
(true, true) => return lhs.clone(), // lhs = rhs
(true, false) => return lhs.clone(), // lhs :> rhs
(false, true) => return rhs.clone(),
(false, false) => {}
}
}
match (lhs, rhs) {
(FreeVar(fv), other) | (other, FreeVar(fv)) if fv.is_linked() => {
self.union(&fv.crack(), other)
@ -1007,21 +998,45 @@ impl Context {
let mut unified_params = vec![];
for (lp, rp) in lps.iter().zip(rps.iter()) {
match (lp, rp) {
(TyParam::Value(ValueObj::Type(l)), TyParam::Value(ValueObj::Type(r))) => {
unified_params.push(TyParam::t(self.union(l.typ(), r.typ())));
}
(TyParam::Value(ValueObj::Type(l)), TyParam::Type(r)) => {
unified_params.push(TyParam::t(self.union(l.typ(), r)));
}
(TyParam::Type(l), TyParam::Value(ValueObj::Type(r))) => {
unified_params.push(TyParam::t(self.union(l, r.typ())));
}
(TyParam::Type(l), TyParam::Type(r)) => {
unified_params.push(TyParam::t(self.union(l, r)))
unified_params.push(TyParam::t(self.union(l, r)));
}
(_, _) => {
if self.eq_tp(lp, rp) {
unified_params.push(lp.clone());
} else {
return or(lhs.clone(), rhs.clone());
return self.simple_union(lhs, rhs);
}
}
}
}
poly(ln, unified_params)
}
(l, r) => or(l.clone(), r.clone()),
_ => self.simple_union(lhs, rhs),
}
}
fn simple_union(&self, lhs: &Type, rhs: &Type) -> Type {
// `?T or ?U` will not be unified
// `Set!(?T(<: Int), 3) or Set(?U(<: Nat), 3)` wii be unified to Set(?T, 3)
if lhs.is_unbound_var() || rhs.is_unbound_var() {
or(lhs.clone(), rhs.clone())
} else {
match (self.supertype_of(lhs, rhs), self.subtype_of(lhs, rhs)) {
(true, true) => lhs.clone(), // lhs = rhs
(true, false) => lhs.clone(), // lhs :> rhs
(false, true) => rhs.clone(),
(false, false) => or(lhs.clone(), rhs.clone()),
}
}
}
@ -1040,15 +1055,6 @@ impl Context {
if lhs == rhs {
return lhs.clone();
}
// ?T and ?U will not be unified
if !lhs.is_unbound_var() && !rhs.is_unbound_var() {
match (self.supertype_of(lhs, rhs), self.subtype_of(lhs, rhs)) {
(true, true) => return lhs.clone(), // lhs = rhs
(true, false) => return rhs.clone(), // lhs :> rhs
(false, true) => return lhs.clone(),
(false, false) => {}
}
}
match (lhs, rhs) {
(FreeVar(fv), other) | (other, FreeVar(fv)) if fv.is_linked() => {
self.intersection(&fv.crack(), other)
@ -1071,10 +1077,29 @@ impl Context {
Type::Record(rec2) => Type::Record(rec.clone().diff(rec2)),
_ => Type::Never,
},
(l, r) if self.is_trait(l) && self.is_trait(r) => and(l.clone(), r.clone()),
(_, Not(r)) => self.diff(lhs, r),
(Not(l), _) => self.diff(rhs, l),
(_l, _r) => Type::Never,
_ => self.simple_intersection(lhs, rhs),
}
}
fn simple_intersection(&self, lhs: &Type, rhs: &Type) -> Type {
// ?T and ?U will not be unified
if lhs.is_unbound_var() || rhs.is_unbound_var() {
and(lhs.clone(), rhs.clone())
} else {
match (self.supertype_of(lhs, rhs), self.subtype_of(lhs, rhs)) {
(true, true) => lhs.clone(), // lhs = rhs
(true, false) => rhs.clone(), // lhs :> rhs
(false, true) => lhs.clone(),
(false, false) => {
if self.is_trait(lhs) && self.is_trait(rhs) {
and(lhs.clone(), rhs.clone())
} else {
Type::Never
}
}
}
}
}