Fix a refinement subtype checking bug

This commit is contained in:
Shunsuke Shibayama 2022-12-10 13:18:07 +09:00
parent 747974d37c
commit acb3eac043
3 changed files with 77 additions and 41 deletions

View file

@ -9,12 +9,10 @@ use crate::ty::typaram::{OpKind, TyParam, TyParamOrdering};
use crate::ty::value::ValueObj;
use crate::ty::value::ValueObj::Inf;
use crate::ty::{Predicate, RefinementType, SubrKind, SubrType, Type};
use erg_common::fresh::fresh_varname;
use Predicate as Pred;
use erg_common::dict::Dict;
use erg_common::Str;
use erg_common::{assume_unreachable, log, set};
use erg_common::{assume_unreachable, log};
use TyParamOrdering::*;
use Type::*;
@ -501,10 +499,15 @@ impl Context {
// ({I: Int | I >= 0} :> {I: Int | I >= 1}) == true,
// ({I: Int | I >= 0} :> {N: Nat | N >= 1}) == true,
// ({I: Int | I > 1 or I < -1} :> {I: Int | I >= 0}) == false,
// ({I: Int | I >= 0} :> {F: Float | F >= 0}) == false,
// {1, 2, 3} :> {1, } == true
(Refinement(l), Refinement(r)) => {
if !self.supertype_of(&l.t, &r.t) && !self.supertype_of(&r.t, &l.t) {
return false;
match (self.subtype_of(&l.t, &r.t), self.supertype_of(&l.t, &r.t)) {
// no relation
(false, false)
// l.t <: r.t (not equal)
| (true, false) => { return false; }
_ => {}
}
let mut r_preds_clone = r.preds.clone();
for l_pred in l.preds.iter() {
@ -520,11 +523,11 @@ impl Context {
r_preds_clone.is_empty()
}
(Nat, re @ Refinement(_)) => {
let nat = Type::Refinement(self.into_refinement(Nat));
let nat = Type::Refinement(Nat.into_refinement());
self.structural_supertype_of(&nat, re)
}
(re @ Refinement(_), Nat) => {
let nat = Type::Refinement(self.into_refinement(Nat));
let nat = Type::Refinement(Nat.into_refinement());
self.structural_supertype_of(re, &nat)
}
// Int :> {I: Int | ...} == true
@ -541,7 +544,7 @@ impl Context {
if self.supertype_of(&l, &r.t) {
return true;
}
let l = Type::Refinement(self.into_refinement(l));
let l = Type::Refinement(l.into_refinement());
self.structural_supertype_of(&l, rhs)
}
// ({I: Int | True} :> Int) == true, ({N: Nat | ...} :> Int) == false, ({I: Int | I >= 0} :> Int) == false
@ -835,33 +838,6 @@ impl Context {
}
}
#[allow(clippy::wrong_self_convention)]
pub(crate) fn into_refinement(&self, t: Type) -> RefinementType {
match t {
Nat => {
let var = Str::from(fresh_varname());
RefinementType::new(
var.clone(),
Int,
set! {Predicate::ge(var, TyParam::value(0))},
)
}
Bool => {
let var = Str::from(fresh_varname());
RefinementType::new(
var.clone(),
Int,
set! {Predicate::ge(var.clone(), TyParam::value(true)), Predicate::le(var, TyParam::value(false))},
)
}
Refinement(r) => r,
t => {
let var = Str::from(fresh_varname());
RefinementType::new(var, t, set! {})
}
}
}
/// returns union of two types (A or B)
pub(crate) fn union(&self, lhs: &Type, rhs: &Type) -> Type {
if lhs == rhs {
@ -888,7 +864,7 @@ impl Context {
(other, Refinement(refine)) | (Refinement(refine), other)
if !other.is_unbound_var() =>
{
let other = self.into_refinement(other.clone());
let other = other.clone().into_refinement();
Type::Refinement(self.union_refinement(&other, refine))
}
// Array({1, 2}, 2), Array({3, 4}, 2) ==> Array({1, 2, 3, 4}, 2)

View file

@ -1592,7 +1592,7 @@ impl Context {
},
// {I: Int | I >= 1} <: Nat == {I: Int | I >= 0}
(Type::Refinement(_), sup) => {
let sup = self.into_refinement(sup.clone());
let sup = sup.clone().into_refinement();
self.sub_unify(maybe_sub, &Type::Refinement(sup), loc, param_name)
},
(Type::Subr(_) | Type::Record(_), Type) => Ok(()),

View file

@ -836,12 +836,29 @@ impl LimitedDisplay for RefinementType {
impl RefinementType {
pub fn new(var: Str, t: Type, preds: Set<Predicate>) -> Self {
match t.deconstruct_refinement() {
Ok((inner_var, inner_t, inner_preds)) => {
let new_preds = preds
.into_iter()
.map(|pred| pred.change_subject_name(inner_var.clone()))
.collect::<Set<_>>();
Self {
var: inner_var,
t: Box::new(inner_t),
preds: inner_preds.concat(new_preds),
}
}
Err(t) => Self {
var,
t: Box::new(t),
preds,
},
}
}
pub fn deconstruct(self) -> (Str, Type, Set<Predicate>) {
(self.var, *self.t, self.preds)
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
@ -1611,6 +1628,14 @@ impl Type {
}
}
pub fn is_refinement(&self) -> bool {
match self {
Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_refinement(),
Self::Refinement(_) => true,
_ => false,
}
}
pub fn is_record(&self) -> bool {
match self {
Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_record(),
@ -1852,6 +1877,41 @@ impl Type {
|| (self.has_no_qvar() && self.has_no_unbound_var())
}
pub fn into_refinement(self) -> RefinementType {
match self {
Type::FreeVar(fv) if fv.is_linked() => fv.crack().clone().into_refinement(),
Type::Nat => {
let var = Str::from(fresh_varname());
RefinementType::new(
var.clone(),
Type::Int,
set! {Predicate::ge(var, TyParam::value(0))},
)
}
Type::Bool => {
let var = Str::from(fresh_varname());
RefinementType::new(
var.clone(),
Type::Int,
set! {Predicate::ge(var.clone(), TyParam::value(true)), Predicate::le(var, TyParam::value(false))},
)
}
Type::Refinement(r) => r,
t => {
let var = Str::from(fresh_varname());
RefinementType::new(var, t, set! {})
}
}
}
pub fn deconstruct_refinement(self) -> Result<(Str, Type, Set<Predicate>), Type> {
match self {
Type::FreeVar(fv) if fv.is_linked() => fv.crack().clone().deconstruct_refinement(),
Type::Refinement(r) => Ok(r.deconstruct()),
_ => Err(self),
}
}
pub fn qvars(&self) -> Set<(Str, Constraint)> {
match self {
Self::FreeVar(fv) if fv.is_linked() => fv.forced_as_ref().linked().unwrap().qvars(),