fix: type-variable union bugs

This commit is contained in:
Shunsuke Shibayama 2023-04-24 21:48:05 +09:00
parent 0e8dee3cbf
commit a74309cbb3
9 changed files with 219 additions and 134 deletions

View file

@ -73,6 +73,15 @@ impl<T, E> Triple<T, E> {
}
}
impl<T> Triple<T, T> {
pub fn either(self) -> Option<T> {
match self {
Triple::None => None,
Triple::Ok(a) | Triple::Err(a) => Some(a),
}
}
}
impl<T, E: std::error::Error> Triple<T, E> {
#[track_caller]
pub fn unwrap(self) -> T {

View file

@ -5,11 +5,11 @@ use erg_common::dict::Dict;
use erg_common::error::MultiErrorDisplay;
use erg_common::style::colors::DEBUG_ERROR;
use erg_common::traits::StructuralEq;
use erg_common::Str;
use erg_common::{assume_unreachable, log};
use erg_common::{Str, Triple};
use crate::ty::constructors::{and, bounded, not, or, poly};
use crate::ty::free::{Constraint, FreeKind};
use crate::ty::free::{Constraint, FreeKind, FreeTyVar};
use crate::ty::typaram::{OpKind, TyParam, TyParamOrdering};
use crate::ty::value::ValueObj;
use crate::ty::value::ValueObj::Inf;
@ -985,10 +985,21 @@ impl Context {
sub: sub2,
sup: sup2,
},
) => match (self.max(sub, sub2), self.min(sup, sup2)) {
) => match (self.max(sub, sub2).either(), self.min(sup, sup2).either()) {
(Some(sub), Some(sup)) => bounded(sub.clone(), sup.clone()),
_ => self.simple_union(lhs, rhs),
},
(other, or @ Or(l, r)) | (or @ Or(l, r), other) => {
if &self.union(other, l) == l.as_ref() || &self.union(other, r) == r.as_ref() {
or.clone()
} else if &self.union(other, l) == other {
self.union(other, r)
} else if &self.union(other, r) == other {
self.union(other, l)
} else {
self.simple_union(lhs, rhs)
}
}
(t, Type::Never) | (Type::Never, t) => t.clone(),
// Array({1, 2}, 2), Array({3, 4}, 2) ==> Array({1, 2, 3, 4}, 2)
(
@ -1049,12 +1060,31 @@ impl Context {
}
}
/// ```erg
/// simple_union(?T, ?U) == ?T or ?U
/// union(Set!(?T(<: Int), 3), Set(?U(<: Nat), 3)) == Set(?T, 3)
/// simple_union(?T(<: Int), Int) == Int or ?T
/// simple_union(?T(:> Int), Int) == ?T
/// ```
fn simple_union(&self, lhs: &Type, rhs: &Type) -> Type {
// `?T or ?U` will not be unified
// `Set!(?T(<: Int), 3) or Set(?U(<: Nat), 3)` will be unified to Set(?T, 3)
if lhs.is_unbound_var() || rhs.is_unbound_var() {
or(lhs.clone(), rhs.clone())
if let Ok(free) = <&FreeTyVar>::try_from(lhs) {
if !rhs.is_totally_unbound() && self.supertype_of(&free.get_sub().unwrap_or(Never), rhs)
{
lhs.clone()
} else {
or(lhs.clone(), rhs.clone())
}
} else if let Ok(free) = <&FreeTyVar>::try_from(rhs) {
if !lhs.is_totally_unbound() && self.supertype_of(&free.get_sub().unwrap_or(Never), lhs)
{
rhs.clone()
} else {
or(lhs.clone(), rhs.clone())
}
} else {
if lhs.is_totally_unbound() || rhs.is_totally_unbound() {
return or(lhs.clone(), rhs.clone());
}
match (self.supertype_of(lhs, rhs), self.subtype_of(lhs, rhs)) {
(true, true) => lhs.clone(), // lhs = rhs
(true, false) => lhs.clone(), // lhs :> rhs
@ -1388,29 +1418,29 @@ impl Context {
/// Return None if they are not related
/// lhsとrhsが包含関係にあるとき小さいほうを返す
/// 関係なければNoneを返す
pub(crate) fn min<'t>(&self, lhs: &'t Type, rhs: &'t Type) -> Option<&'t Type> {
pub(crate) fn min<'t>(&self, lhs: &'t Type, rhs: &'t Type) -> Triple<&'t Type, &'t Type> {
// If they are the same, either one can be returned.
match (self.supertype_of(lhs, rhs), self.subtype_of(lhs, rhs)) {
(true, true) | (true, false) => Some(rhs),
(false, true) => Some(lhs),
(false, false) => None,
(true, true) | (true, false) => Triple::Err(rhs),
(false, true) => Triple::Ok(lhs),
(false, false) => Triple::None,
}
}
pub(crate) fn max<'t>(&self, lhs: &'t Type, rhs: &'t Type) -> Option<&'t Type> {
pub(crate) fn max<'t>(&self, lhs: &'t Type, rhs: &'t Type) -> Triple<&'t Type, &'t Type> {
// If they are the same, either one can be returned.
match (self.supertype_of(lhs, rhs), self.subtype_of(lhs, rhs)) {
(true, true) | (true, false) => Some(lhs),
(false, true) => Some(rhs),
(false, false) => None,
(true, true) | (true, false) => Triple::Ok(lhs),
(false, true) => Triple::Err(rhs),
(false, false) => Triple::None,
}
}
pub(crate) fn cmp_t<'t>(&self, lhs: &'t Type, rhs: &'t Type) -> TyParamOrdering {
match self.min(lhs, rhs) {
Some(l) if l == lhs => TyParamOrdering::Less,
Some(_) => TyParamOrdering::Greater,
None => TyParamOrdering::NoRelation,
Triple::Ok(_) => TyParamOrdering::Less,
Triple::Err(_) => TyParamOrdering::Greater,
Triple::None => TyParamOrdering::NoRelation,
}
}
}

View file

@ -1235,15 +1235,14 @@ impl Context {
/// ```erg
/// squash_tyvar(?1 or ?2) == ?1(== ?2)
/// squash_tyvar(?T or ?U) == ?T or ?U
/// squash_tyvar(?T or NoneType) == ?T or NoneType
/// ```
pub(crate) fn squash_tyvar(&self, typ: Type) -> Type {
match typ {
Type::Or(l, r) => {
let l = self.squash_tyvar(*l);
let r = self.squash_tyvar(*r);
if l.is_named_unbound_var() && r.is_named_unbound_var() {
self.union(&l, &r)
} else {
if l.is_unnamed_unbound_var() && r.is_unnamed_unbound_var() {
match (self.subtype_of(&l, &r), self.subtype_of(&r, &l)) {
(true, true) | (true, false) => {
let _ = self.sub_unify(&l, &r, &(), None);
@ -1253,8 +1252,8 @@ impl Context {
}
_ => {}
}
self.union(&l, &r)
}
self.union(&l, &r)
}
other => other,
}

View file

@ -23,14 +23,17 @@ use Type::*;
use ValueObj::{Inf, NegInf};
impl Context {
/// ```erg
/// occur(?T, ?T) ==> OK
/// occur(X -> ?T, ?T) ==> Error
/// occur(X -> ?T, X -> ?T) ==> OK
/// occur(?T, ?T -> X) ==> Error
/// occur(?T, Option(?T)) ==> Error
/// occur(?T or ?U, ?T) ==> OK
/// occur(?T or Int, Int or ?T) ==> OK
/// occur(?T(<: Str) or ?U(<: Int), ?T(<: Str)) ==> Error
/// occur(?T, ?T.Output) ==> OK
/// ```
pub(crate) fn occur(
&self,
maybe_sub: &Type,
@ -118,10 +121,10 @@ impl Context {
}
Ok(())
}
(Or(l, r), Or(l2, r2)) | (And(l, r), And(l2, r2)) => {
self.occur(l, l2, loc)?;
self.occur(r, r2, loc)
}
(Or(l, r), Or(l2, r2)) | (And(l, r), And(l2, r2)) => self
.occur(l, l2, loc)
.and(self.occur(r, r2, loc))
.or(self.occur(l, r2, loc).and(self.occur(r, l2, loc))),
(lhs, Or(l, r)) | (lhs, And(l, r)) => {
self.occur_inner(lhs, l, loc)?;
self.occur_inner(lhs, r, loc)
@ -602,7 +605,10 @@ impl Context {
if maybe_sub == &Type::Failure || maybe_sup == &Type::Failure {
return Ok(());
}
self.occur(maybe_sub, maybe_sup, loc)?;
self.occur(maybe_sub, maybe_sup, loc).map_err(|err| {
log!(err "occur error: {maybe_sub} / {maybe_sup}");
err
})?;
let maybe_sub_is_sub = self.subtype_of(maybe_sub, maybe_sup);
if !maybe_sub_is_sub {
log!(err "{maybe_sub} !<: {maybe_sup}");
@ -668,26 +674,7 @@ impl Context {
}
sup_fv.undo();
let intersec = self.intersection(&lsup, &rsup);
let new_constraint = if intersec != Type::Never {
let union = self.union(&lsub, &rsub);
if !lsub.has_union_type() && !rsub.has_union_type() && union.has_union_type() {
let (l, r) = union.union_pair().unwrap_or((lsub, rsub));
let unified = self.unify(&l, &r);
if unified.is_none() {
return Err(TyCheckErrors::from(
TyCheckError::implicit_widening_error(
self.cfg.input.clone(),
line!() as usize,
loc.loc(),
self.caused_by(),
maybe_sub,
maybe_sup,
),
));
}
}
Constraint::new_sandwiched(union, intersec)
} else {
if intersec == Type::Never {
return Err(TyCheckErrors::from(TyCheckError::subtyping_error(
self.cfg.input.clone(),
line!() as usize,
@ -696,7 +683,23 @@ impl Context {
loc.loc(),
self.caused_by(),
)));
};
}
let union = self.union(&lsub, &rsub);
if lsub.union_size().max(rsub.union_size()) < union.union_size() {
let (l, r) = union.union_pair().unwrap_or((lsub, rsub));
let unified = self.unify(&l, &r);
if unified.is_none() {
return Err(TyCheckErrors::from(TyCheckError::implicit_widening_error(
self.cfg.input.clone(),
line!() as usize,
loc.loc(),
self.caused_by(),
maybe_sub,
maybe_sup,
)));
}
}
let new_constraint = Constraint::new_sandwiched(union, intersec);
match sub_fv
.level()
.unwrap_or(GENERIC_LEVEL)
@ -723,6 +726,26 @@ impl Context {
}
Ok(())
}
// (Int or ?T) <: (?U or Int)
// OK: (Int <: Int); (?T <: ?U)
// NG: (Int <: ?U); (?T <: Int)
(Or(l1, r1), Or(l2, r2)) | (And(l1, r1), And(l2, r2)) => {
if self.subtype_of(l1, l2) && self.subtype_of(r1, r2) {
let (l_sup, r_sup) = if self.subtype_of(l1, r2)
&& !l1.is_unbound_var()
&& !r2.is_unbound_var()
{
(r2, l2)
} else {
(l2, r2)
};
self.sub_unify(l1, l_sup, loc, param_name)?;
self.sub_unify(r1, r_sup, loc, param_name)
} else {
self.sub_unify(l1, r2, loc, param_name)?;
self.sub_unify(r1, l2, loc, param_name)
}
}
// NG: Nat <: ?T or Int ==> Nat or Int (?T = Nat)
// OK: Nat <: ?T or Int ==> ?T or Int
(sub, Or(l, r))
@ -747,6 +770,7 @@ impl Context {
{
Ok(())
}
(_, FreeVar(sup_fv)) if sup_fv.is_generalized() => Ok(()),
(_, FreeVar(sup_fv)) if sup_fv.is_unbound() => {
// * sub_unify(Nat, ?E(<: Eq(?E)))
// sub !<: l => OK (sub will widen)
@ -770,9 +794,8 @@ impl Context {
// Expanding to an Or-type is prohibited by default
// This increases the quality of error reporting
// (Try commenting out this part and run tests/should_err/subtyping.er to see the error report changes on lines 29-30)
if !maybe_sub.has_union_type()
&& !sub.has_union_type()
&& new_sub.has_union_type()
if maybe_sub.union_size().max(sub.union_size()) < new_sub.union_size()
&& new_sub.union_types().iter().any(|t| !t.is_unbound_var())
{
let (l, r) = new_sub.union_pair().unwrap_or((maybe_sub.clone(), sub));
let unified = self.unify(&l, &r);
@ -825,6 +848,7 @@ impl Context {
(FreeVar(sub_fv), Ref(sup)) if sub_fv.is_unbound() => {
self.sub_unify(maybe_sub, sup, loc, param_name)
}
(FreeVar(sub_fv), _) if sub_fv.is_generalized() => Ok(()),
(FreeVar(sub_fv), _) if sub_fv.is_unbound() => {
// sub !<: r => Error
// * sub_unify(?T(:> Int, <: _), Nat): (/* Error */)
@ -844,7 +868,7 @@ impl Context {
return Ok(());
}
let sub = mem::take(&mut sub);
let new_sup = if let Some(new_sup) = self.min(&sup, maybe_sup) {
let new_sup = if let Some(new_sup) = self.min(&sup, maybe_sup).either() {
new_sup.clone()
} else {
self.intersection(&sup, maybe_sup)
@ -1050,15 +1074,6 @@ impl Context {
}
Ok(())
}
(Or(l1, r1), Or(l2, r2)) | (Type::And(l1, r1), Type::And(l2, r2)) => {
if self.subtype_of(l1, l2) && self.subtype_of(r1, r2) {
self.sub_unify(l1, l2, loc, param_name)?;
self.sub_unify(r1, r2, loc, param_name)
} else {
self.sub_unify(l1, r2, loc, param_name)?;
self.sub_unify(r1, l2, loc, param_name)
}
}
// (X or Y) <: Z is valid when X <: Z and Y <: Z
(Or(l, r), _) => {
self.sub_unify(l, maybe_sup, loc, param_name)?;
@ -1070,14 +1085,22 @@ impl Context {
self.sub_unify(maybe_sub, r, loc, param_name)
}
// (X and Y) <: Z is valid when X <: Z or Y <: Z
(And(l, r), _) => self
.sub_unify(l, maybe_sup, loc, param_name)
.or_else(|_e| self.sub_unify(r, maybe_sup, loc, param_name)),
(And(l, r), _) => {
if self.subtype_of(l, maybe_sup) {
self.sub_unify(l, maybe_sup, loc, param_name)
} else {
self.sub_unify(r, maybe_sup, loc, param_name)
}
}
// X <: (Y or Z) is valid when X <: Y or X <: Z
(_, Or(l, r)) => self
.sub_unify(maybe_sub, l, loc, param_name)
.or_else(|_e| self.sub_unify(maybe_sub, r, loc, param_name)),
(Ref(l), Ref(r)) => self.sub_unify(l, r, loc, param_name),
(_, Or(l, r)) => {
if self.subtype_of(maybe_sub, l) {
self.sub_unify(maybe_sub, l, loc, param_name)
} else {
self.sub_unify(maybe_sub, r, loc, param_name)
}
}
(Ref(sub), Ref(sup)) => self.sub_unify(sub, sup, loc, param_name),
(_, Ref(t)) => self.sub_unify(maybe_sub, t, loc, param_name),
(RefMut { before: l, .. }, RefMut { before: r, .. }) => {
self.sub_unify(l, r, loc, param_name)
@ -1212,7 +1235,7 @@ impl Context {
if self.supertype_of(&r_sup, &Obj) {
continue;
}
if let Some(t) = self.max(&l_sup, &r_sup) {
if let Some(t) = self.max(&l_sup, &r_sup).either() {
return Some(t.clone());
}
}

View file

@ -467,6 +467,10 @@ impl<T> FreeKind<T> {
matches!(self, Self::NamedUnbound { .. })
}
/// Returns `true` if this is the `Unbound` variant
/// (the `NamedUnbound` variant does not count).
pub const fn is_unnamed_unbound(&self) -> bool {
matches!(self, Self::Unbound { .. })
}
pub const fn is_undoable_linked(&self) -> bool {
matches!(self, Self::UndoableLinked { .. })
}
@ -772,6 +776,10 @@ impl<T> Free<T> {
self.borrow().is_named_unbound()
}
/// Borrows the inner value and returns the result of its
/// `is_unnamed_unbound` check.
pub fn is_unnamed_unbound(&self) -> bool {
self.borrow().is_unnamed_unbound()
}
pub fn unsafe_crack(&self) -> &T {
match unsafe { self.as_ptr().as_ref().unwrap() } {
FreeKind::Linked(t) | FreeKind::UndoableLinked { t, .. } => t,

View file

@ -1710,54 +1710,49 @@ impl Type {
}
}
pub fn has_union_type(&self) -> bool {
pub fn union_size(&self) -> usize {
match self {
Self::FreeVar(fv) if fv.is_linked() => fv.crack().has_union_type(),
Self::FreeVar(fv) if fv.is_linked() => fv.crack().union_size(),
Self::FreeVar(fv) if fv.constraint_is_sandwiched() => {
let (sub, sup) = fv.get_subsup().unwrap();
fv.dummy_link();
let res = sub.has_union_type() || sup.has_union_type();
let res = sub.union_size().max(sup.union_size());
fv.undo();
res
}
Self::Or(_, _) => true,
Self::Refinement(refine) => refine.t.has_union_type(),
Self::Ref(t) => t.has_union_type(),
Self::RefMut { before, after } => {
before.has_union_type()
|| after.as_ref().map(|t| t.has_union_type()).unwrap_or(false)
}
Self::And(lhs, rhs) => lhs.has_union_type() || rhs.has_union_type(),
Self::Not(ty) => ty.has_union_type(),
Self::Callable { param_ts, return_t } => {
param_ts.iter().any(|t| t.has_union_type()) || return_t.has_union_type()
}
Self::Subr(subr) => {
subr.non_default_params
.iter()
.any(|pt| pt.typ().has_union_type())
|| subr
.var_params
.as_ref()
.map(|pt| pt.typ().has_union_type())
.unwrap_or(false)
|| subr
.default_params
.iter()
.any(|pt| pt.typ().has_union_type())
|| subr.return_t.has_union_type()
}
Self::Record(r) => r.values().any(|t| t.has_union_type()),
Self::Quantified(quant) => quant.has_union_type(),
Self::Poly { params, .. } => params.iter().any(|p| p.has_union_type()),
Self::Proj { lhs, .. } => lhs.has_union_type(),
Self::ProjCall { lhs, args, .. } => {
lhs.has_union_type() || args.iter().any(|t| t.has_union_type())
}
Self::Structural(ty) => ty.has_union_type(),
Self::Guard(guard) => guard.to.has_union_type(),
Self::Bounded { sub, sup } => sub.has_union_type() || sup.has_union_type(),
_ => false,
// Or(Or(Int, Str), Nat) == 3
Self::Or(l, r) => l.union_size() + r.union_size(),
Self::Refinement(refine) => refine.t.union_size(),
Self::Ref(t) => t.union_size(),
Self::RefMut { before, after: _ } => before.union_size(),
Self::And(lhs, rhs) => lhs.union_size().max(rhs.union_size()),
Self::Not(ty) => ty.union_size(),
Self::Callable { param_ts, return_t } => param_ts
.iter()
.map(|t| t.union_size())
.max()
.unwrap_or(1)
.max(return_t.union_size()),
Self::Subr(subr) => subr
.non_default_params
.iter()
.map(|pt| pt.typ().union_size())
.chain(subr.var_params.as_ref().map(|pt| pt.typ().union_size()))
.chain(subr.default_params.iter().map(|pt| pt.typ().union_size()))
.max()
.unwrap_or(1)
.max(subr.return_t.union_size()),
Self::Record(r) => r.values().map(|t| t.union_size()).max().unwrap_or(1),
Self::Quantified(quant) => quant.union_size(),
Self::Poly { params, .. } => params.iter().map(|p| p.union_size()).max().unwrap_or(1),
Self::Proj { lhs, .. } => lhs.union_size(),
Self::ProjCall { lhs, args, .. } => lhs
.union_size()
.max(args.iter().map(|t| t.union_size()).max().unwrap_or(1)),
Self::Structural(ty) => ty.union_size(),
Self::Guard(guard) => guard.to.union_size(),
Self::Bounded { sub, sup } => sub.union_size().max(sup.union_size()),
_ => 1,
}
}
@ -2149,6 +2144,14 @@ impl Type {
}
}
/// Reports whether `typ` is one of the members of this (possibly nested) union.
///
/// An `Or` is never compared as a whole: both of its sides are searched
/// recursively, and only non-`Or` types are compared to `typ` with `==`.
///
/// assert!((A or B).contains_union(B))
pub fn contains_union(&self, typ: &Type) -> bool {
    if let Type::Or(lhs, rhs) = self {
        lhs.contains_union(typ) || rhs.contains_union(typ)
    } else {
        self == typ
    }
}
pub fn intersection_types(&self) -> Vec<Type> {
match self {
Type::FreeVar(fv) if fv.is_linked() => fv.crack().intersection_types(),
@ -2162,14 +2165,6 @@ impl Type {
}
}
/// assert!((A or B).contains_union(B))
pub fn contains_union(&self, typ: &Type) -> bool {
match self {
Type::Or(t1, t2) => t1.contains_union(typ) || t2.contains_union(typ),
_ => self == typ,
}
}
pub fn tvar_name(&self) -> Option<Str> {
match self {
Self::FreeVar(fv) if fv.is_linked() => fv.crack().tvar_name(),
@ -2220,6 +2215,13 @@ impl Type {
matches!(self, Self::FreeVar(fv) if fv.is_named_unbound() || (fv.is_linked() && fv.crack().is_named_unbound_var()))
}
/// Returns `true` if this is a `FreeVar` that is itself unnamed-unbound, or
/// that is linked to a type for which this check (applied recursively) holds.
/// Every other kind of type yields `false`.
pub fn is_unnamed_unbound_var(&self) -> bool {
    match self {
        Self::FreeVar(fv) => {
            fv.is_unnamed_unbound() || (fv.is_linked() && fv.crack().is_unnamed_unbound_var())
        }
        _ => false,
    }
}
/// ```erg
/// assert (?T or ?U).is_totally_unbound()
/// ```
pub fn is_totally_unbound(&self) -> bool {
match self {
Self::FreeVar(fv) if fv.is_unbound() => true,

View file

@ -1049,24 +1049,33 @@ impl TyParam {
!self.has_unbound_var()
}
pub fn has_union_type(&self) -> bool {
pub fn union_size(&self) -> usize {
match self {
Self::FreeVar(fv) if fv.is_linked() => fv.crack().has_union_type(),
Self::Type(t) => t.has_union_type(),
Self::Proj { obj, .. } => obj.has_union_type(),
Self::Array(ts) | Self::Tuple(ts) => ts.iter().any(|t| t.has_union_type()),
Self::Set(ts) => ts.iter().any(|t| t.has_union_type()),
Self::FreeVar(fv) if fv.is_linked() => fv.crack().union_size(),
Self::Type(t) => t.union_size(),
Self::Proj { obj, .. } => obj.union_size(),
Self::Array(ts) | Self::Tuple(ts) => {
ts.iter().map(|t| t.union_size()).max().unwrap_or(1)
}
Self::Set(ts) => ts.iter().map(|t| t.union_size()).max().unwrap_or(1),
Self::Dict(kv) => kv
.iter()
.any(|(k, v)| k.has_union_type() || v.has_union_type()),
Self::Record(rec) => rec.iter().any(|(_, v)| v.has_union_type()),
Self::Lambda(lambda) => lambda.body.iter().any(|t| t.has_union_type()),
Self::UnaryOp { val, .. } => val.has_union_type(),
Self::BinOp { lhs, rhs, .. } => lhs.has_union_type() || rhs.has_union_type(),
Self::App { args, .. } => args.iter().any(|p| p.has_union_type()),
Self::Erased(t) => t.has_union_type(),
Self::Value(ValueObj::Type(t)) => t.typ().has_union_type(),
_ => false,
.map(|(k, v)| k.union_size().max(v.union_size()))
.max()
.unwrap_or(1),
Self::Record(rec) => rec.iter().map(|(_, v)| v.union_size()).max().unwrap_or(1),
Self::Lambda(lambda) => lambda
.body
.iter()
.map(|t| t.union_size())
.max()
.unwrap_or(1),
Self::UnaryOp { val, .. } => val.union_size(),
Self::BinOp { lhs, rhs, .. } => lhs.union_size().max(rhs.union_size()),
Self::App { args, .. } => args.iter().map(|p| p.union_size()).max().unwrap_or(1),
Self::Erased(t) => t.union_size(),
Self::Value(ValueObj::Type(t)) => t.typ().union_size(),
_ => 1,
}
}

View file

@ -5,3 +5,8 @@ w = ![]
w.push! "a"
_ = v.concat w # ERR
i_s = ![1 as (Int or Str)]
i_s.push! "b"
i_s.push! 2
i_s.push! None # ERR

View file

@ -344,7 +344,7 @@ fn exec_mut_err() -> Result<(), ()> {
#[test]
fn exec_mut_array_err() -> Result<(), ()> {
expect_failure("tests/should_err/mut_array.er", 0, 1)
expect_failure("tests/should_err/mut_array.er", 0, 2)
}
#[test]