fix: union/intersection types bugs

This commit is contained in:
Shunsuke Shibayama 2023-03-17 10:18:23 +09:00
parent 3ff0eb8f65
commit fd99524bbe
6 changed files with 93 additions and 12 deletions

View file

@ -339,10 +339,10 @@ impl Context {
}
// NG: expr_t: Nat, union_pat_t: {1, 2}
// OK: expr_t: Int, union_pat_t: {1} or 'T
if self
.sub_unify(match_target_expr_t, &union_pat_t, &pos_args[0], None)
.is_err()
{
if let Err(err) = self.sub_unify(match_target_expr_t, &union_pat_t, &pos_args[0], None) {
if cfg!(feature = "debug") {
eprintln!("match error: {err}");
}
return Err(TyCheckErrors::from(TyCheckError::match_error(
self.cfg.input.clone(),
line!() as usize,

View file

@ -24,6 +24,7 @@ use Type::*;
use ValueObj::{Inf, NegInf};
impl Context {
/// occur(?T, ?T) ==> Error
/// occur(X -> ?T, ?T) ==> Error
/// occur(?T, ?T -> X) ==> Error
/// occur(?T, Option(?T)) ==> Error
@ -937,6 +938,15 @@ impl Context {
params: sup_params, ..
},
) => self.nominal_sub_unify(maybe_sub, maybe_sup, sup_params, loc),
(Type::Or(l1, r1), Type::Or(l2, r2)) | (Type::And(l1, r1), Type::And(l2, r2)) => {
if self.subtype_of(l1, l2) && self.subtype_of(r1, r2) {
self.sub_unify(l1, l2, loc, param_name)?;
self.sub_unify(r1, r2, loc, param_name)
} else {
self.sub_unify(l1, r2, loc, param_name)?;
self.sub_unify(r1, l2, loc, param_name)
}
}
// (X or Y) <: Z is valid when X <: Z and Y <: Z
(Type::Or(l, r), _) => {
self.sub_unify(l, maybe_sup, loc, param_name)?;

View file

@ -413,11 +413,33 @@ pub fn refinement(var: Str, t: Type, pred: Predicate) -> Type {
}
pub fn and(lhs: Type, rhs: Type) -> Type {
Type::And(Box::new(lhs), Box::new(rhs))
match (lhs, rhs) {
(Type::And(l, r), other) | (other, Type::And(l, r)) => {
if l.as_ref() == &other {
and(*r, other)
} else if r.as_ref() == &other {
and(*l, other)
} else {
Type::And(Box::new(Type::And(l, r)), Box::new(other))
}
}
(lhs, rhs) => Type::And(Box::new(lhs), Box::new(rhs)),
}
}
pub fn or(lhs: Type, rhs: Type) -> Type {
Type::Or(Box::new(lhs), Box::new(rhs))
match (lhs, rhs) {
(Type::Or(l, r), other) | (other, Type::Or(l, r)) => {
if l.as_ref() == &other {
or(*r, other)
} else if r.as_ref() == &other {
or(*l, other)
} else {
Type::Or(Box::new(Type::Or(l, r)), Box::new(other))
}
}
(lhs, rhs) => Type::Or(Box::new(lhs), Box::new(rhs)),
}
}
pub fn not(ty: Type) -> Type {

View file

@ -809,7 +809,7 @@ impl PartialEq for Type {
(Self::Refinement(l), Self::Refinement(r)) => l == r,
(Self::Quantified(l), Self::Quantified(r)) => l == r,
(Self::And(ll, lr), Self::And(rl, rr)) | (Self::Or(ll, lr), Self::Or(rl, rr)) => {
ll == rl && lr == rr
(ll == rl && lr == rr) || (ll == rr && lr == rl)
}
(Self::Not(l), Self::Not(r)) => l == r,
(
@ -846,6 +846,17 @@ impl PartialEq for Type {
(_self, Self::FreeVar(fv)) if fv.is_linked() => _self == &*fv.crack(),
(Self::FreeVar(l), Self::FreeVar(r)) => l == r,
(Self::Failure, Self::Failure) | (Self::Uninited, Self::Uninited) => true,
// NoneType == {None}
(Self::NoneType, Self::Refinement(refine))
| (Self::Refinement(refine), Self::NoneType) => {
matches!(
refine.pred.as_ref(),
Predicate::Equal {
rhs: TyParam::Value(ValueObj::None),
..
}
)
}
_ => false,
}
}
@ -1350,9 +1361,10 @@ impl StructuralEq for Type {
.all(|(a, b)| a.structural_eq(b))
}
(Self::Structural(l), Self::Structural(r)) => l.structural_eq(r),
// TODO: commutative
(Self::And(l, r), Self::And(l2, r2)) => l.structural_eq(l2) && r.structural_eq(r2),
(Self::Or(l, r), Self::Or(l2, r2)) => l.structural_eq(l2) && r.structural_eq(r2),
(Self::And(l, r), Self::And(l2, r2)) | (Self::Or(l, r), Self::Or(l2, r2)) => {
(l.structural_eq(l2) && r.structural_eq(r2))
|| (l.structural_eq(r2) && r.structural_eq(l2))
}
(Self::Not(ty), Self::Not(ty2)) => ty.structural_eq(ty2),
_ => self == other,
}

View file

@ -148,7 +148,22 @@ impl Predicate {
(p, Predicate::Value(ValueObj::Bool(true))) => p,
(Predicate::Value(ValueObj::Bool(false)), _)
| (_, Predicate::Value(ValueObj::Bool(false))) => Predicate::FALSE,
(p1, p2) => Self::And(Box::new(p1), Box::new(p2)),
(Predicate::And(l, r), other) | (other, Predicate::And(l, r)) => {
if l.as_ref() == &other {
*r & other
} else if r.as_ref() == &other {
*l & other
} else {
Self::And(Box::new(Self::And(l, r)), Box::new(other))
}
}
(p1, p2) => {
if p1 == p2 {
p1
} else {
Self::And(Box::new(p1), Box::new(p2))
}
}
}
}
@ -158,7 +173,22 @@ impl Predicate {
| (_, Predicate::Value(ValueObj::Bool(true))) => Predicate::TRUE,
(Predicate::Value(ValueObj::Bool(false)), p) => p,
(p, Predicate::Value(ValueObj::Bool(false))) => p,
(p1, p2) => Self::Or(Box::new(p1), Box::new(p2)),
(Predicate::Or(l, r), other) | (other, Predicate::Or(l, r)) => {
if l.as_ref() == &other {
*r | other
} else if r.as_ref() == &other {
*l | other
} else {
Self::Or(Box::new(Self::Or(l, r)), Box::new(other))
}
}
(p1, p2) => {
if p1 == p2 {
p1
} else {
Self::Or(Box::new(p1), Box::new(p2))
}
}
}
}

View file

@ -12,3 +12,10 @@ print_to_str!|S <: Show|(s: S): Str =
s.to_str()
discard print_to_str!(1)
add1 x: Int = x + 1
then|T|(x: T or NoneType, f: (a: T) -> T) =
match x:
None -> x
(y: T) -> f y
assert then(1, add1) == 2