From acb3eac043d5fbbb56f92d1042fccda372bebc8d Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Sat, 10 Dec 2022 13:18:07 +0900 Subject: [PATCH] Fix a refinement subtype checking bug --- compiler/erg_compiler/context/compare.rs | 48 +++++------------ compiler/erg_compiler/context/tyvar.rs | 2 +- compiler/erg_compiler/ty/mod.rs | 68 ++++++++++++++++++++++-- 3 files changed, 77 insertions(+), 41 deletions(-) diff --git a/compiler/erg_compiler/context/compare.rs b/compiler/erg_compiler/context/compare.rs index b6858530..d087cc04 100644 --- a/compiler/erg_compiler/context/compare.rs +++ b/compiler/erg_compiler/context/compare.rs @@ -9,12 +9,10 @@ use crate::ty::typaram::{OpKind, TyParam, TyParamOrdering}; use crate::ty::value::ValueObj; use crate::ty::value::ValueObj::Inf; use crate::ty::{Predicate, RefinementType, SubrKind, SubrType, Type}; -use erg_common::fresh::fresh_varname; use Predicate as Pred; use erg_common::dict::Dict; -use erg_common::Str; -use erg_common::{assume_unreachable, log, set}; +use erg_common::{assume_unreachable, log}; use TyParamOrdering::*; use Type::*; @@ -501,10 +499,15 @@ impl Context { // ({I: Int | I >= 0} :> {I: Int | I >= 1}) == true, // ({I: Int | I >= 0} :> {N: Nat | N >= 1}) == true, // ({I: Int | I > 1 or I < -1} :> {I: Int | I >= 0}) == false, + // ({I: Int | I >= 0} :> {F: Float | F >= 0}) == false, // {1, 2, 3} :> {1, } == true (Refinement(l), Refinement(r)) => { - if !self.supertype_of(&l.t, &r.t) && !self.supertype_of(&r.t, &l.t) { - return false; + match (self.subtype_of(&l.t, &r.t), self.supertype_of(&l.t, &r.t)) { + // no relation + (false, false) + // l.t <: r.t (not equal) + | (true, false) => { return false; } + _ => {} } let mut r_preds_clone = r.preds.clone(); for l_pred in l.preds.iter() { @@ -520,11 +523,11 @@ impl Context { r_preds_clone.is_empty() } (Nat, re @ Refinement(_)) => { - let nat = Type::Refinement(self.into_refinement(Nat)); + let nat = Type::Refinement(Nat.into_refinement()); self.structural_supertype_of(&nat, re) } (re @ Refinement(_), Nat) => { - let nat = Type::Refinement(self.into_refinement(Nat)); + let nat = Type::Refinement(Nat.into_refinement()); self.structural_supertype_of(re, &nat) } // Int :> {I: Int | ...} == true @@ -541,7 +544,7 @@ impl Context { if self.supertype_of(&l, &r.t) { return true; } - let l = Type::Refinement(self.into_refinement(l)); + let l = Type::Refinement(l.into_refinement()); self.structural_supertype_of(&l, rhs) } // ({I: Int | True} :> Int) == true, ({N: Nat | ...} :> Int) == false, ({I: Int | I >= 0} :> Int) == false @@ -835,33 +838,6 @@ impl Context { } } - #[allow(clippy::wrong_self_convention)] - pub(crate) fn into_refinement(&self, t: Type) -> RefinementType { - match t { - Nat => { - let var = Str::from(fresh_varname()); - RefinementType::new( - var.clone(), - Int, - set! {Predicate::ge(var, TyParam::value(0))}, - ) - } - Bool => { - let var = Str::from(fresh_varname()); - RefinementType::new( - var.clone(), - Int, - set! {Predicate::ge(var.clone(), TyParam::value(true)), Predicate::le(var, TyParam::value(false))}, - ) - } - Refinement(r) => r, - t => { - let var = Str::from(fresh_varname()); - RefinementType::new(var, t, set! {}) - } - } - } - /// returns union of two types (A or B) pub(crate) fn union(&self, lhs: &Type, rhs: &Type) -> Type { if lhs == rhs { @@ -888,7 +864,7 @@ impl Context { (other, Refinement(refine)) | (Refinement(refine), other) if !other.is_unbound_var() => { - let other = self.into_refinement(other.clone()); + let other = other.clone().into_refinement(); Type::Refinement(self.union_refinement(&other, refine)) } // Array({1, 2}, 2), Array({3, 4}, 2) ==> Array({1, 2, 3, 4}, 2) diff --git a/compiler/erg_compiler/context/tyvar.rs b/compiler/erg_compiler/context/tyvar.rs index f21617a6..8e6a7eed 100644 --- a/compiler/erg_compiler/context/tyvar.rs +++ b/compiler/erg_compiler/context/tyvar.rs @@ -1592,7 +1592,7 @@ impl Context { }, // {I: Int | I >= 1} <: Nat == {I: Int | I >= 0} (Type::Refinement(_), sup) => { - let sup = self.into_refinement(sup.clone()); + let sup = sup.clone().into_refinement(); self.sub_unify(maybe_sub, &Type::Refinement(sup), loc, param_name) }, (Type::Subr(_) | Type::Record(_), Type) => Ok(()), diff --git a/compiler/erg_compiler/ty/mod.rs b/compiler/erg_compiler/ty/mod.rs index 4066dcba..5b027046 100644 --- a/compiler/erg_compiler/ty/mod.rs +++ b/compiler/erg_compiler/ty/mod.rs @@ -836,12 +836,29 @@ impl LimitedDisplay for RefinementType { impl RefinementType { pub fn new(var: Str, t: Type, preds: Set) -> Self { - Self { - var, - t: Box::new(t), - preds, + match t.deconstruct_refinement() { + Ok((inner_var, inner_t, inner_preds)) => { + let new_preds = preds + .into_iter() + .map(|pred| pred.change_subject_name(inner_var.clone())) + .collect::>(); + Self { + var: inner_var, + t: Box::new(inner_t), + preds: inner_preds.concat(new_preds), + } + } + Err(t) => Self { + var, + t: Box::new(t), + preds, + }, } } + + pub fn deconstruct(self) -> (Str, Type, Set) { + (self.var, *self.t, self.preds) + } } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -1611,6 +1628,14 @@ impl Type { } } + pub fn is_refinement(&self) -> bool { + match self { + Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_refinement(), + Self::Refinement(_) => true, + _ => false, + } + } + pub fn is_record(&self) -> bool { match self { Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_record(), @@ -1852,6 +1877,41 @@ impl Type { || (self.has_no_qvar() && self.has_no_unbound_var()) } + pub fn into_refinement(self) -> RefinementType { + match self { + Type::FreeVar(fv) if fv.is_linked() => fv.crack().clone().into_refinement(), + Type::Nat => { + let var = Str::from(fresh_varname()); + RefinementType::new( + var.clone(), + Type::Int, + set! {Predicate::ge(var, TyParam::value(0))}, + ) + } + Type::Bool => { + let var = Str::from(fresh_varname()); + RefinementType::new( + var.clone(), + Type::Int, + set! {Predicate::ge(var.clone(), TyParam::value(true)), Predicate::le(var, TyParam::value(false))}, + ) + } + Type::Refinement(r) => r, + t => { + let var = Str::from(fresh_varname()); + RefinementType::new(var, t, set! {}) + } + } + } + + pub fn deconstruct_refinement(self) -> Result<(Str, Type, Set), Type> { + match self { + Type::FreeVar(fv) if fv.is_linked() => fv.crack().clone().deconstruct_refinement(), + Type::Refinement(r) => Ok(r.deconstruct()), + _ => Err(self), + } + } + pub fn qvars(&self) -> Set<(Str, Constraint)> { match self { Self::FreeVar(fv) if fv.is_linked() => fv.forced_as_ref().linked().unwrap().qvars(),