feat: type narrowing with filter

This commit is contained in:
Shunsuke Shibayama 2024-03-21 01:29:27 +09:00
parent 5345b07791
commit bd39393746
9 changed files with 209 additions and 36 deletions

View file

@ -377,7 +377,11 @@ impl Context {
let same_params_len = ls.non_default_params.len() == rs.non_default_params.len()
|| rs.var_params.is_some();
// && ls.default_params.len() <= rs.default_params.len();
let return_t_judge = self.supertype_of(&ls.return_t, &rs.return_t); // covariant
let rhs_ret = rs
.return_t
.clone()
.replace_params(rs.param_names(), ls.param_names());
let return_t_judge = self.supertype_of(&ls.return_t, &rhs_ret); // covariant
let non_defaults_judge = ls
.non_default_params
.iter()
@ -538,6 +542,10 @@ impl Context {
true
}
(Bool, Guard { .. }) => true,
(Guard(lhs), Guard(rhs)) => {
!rhs.to.is_refinement() // TODO: refinement guard is unstable
&& lhs.target == rhs.target && self.supertype_of(&lhs.to, &rhs.to)
}
(Mono(n), NamedTuple(_)) => &n[..] == "GenericNamedTuple" || &n[..] == "GenericTuple",
(Mono(n), Record(_)) => &n[..] == "Record",
(ty @ (Type | ClassType | TraitType), Record(rec)) => {
@ -1500,6 +1508,30 @@ impl Context {
}
}
pub(crate) fn union_pred(&self, lhs: Predicate, rhs: Predicate) -> Predicate {
match (
self.is_super_pred_of(&lhs, &rhs),
self.is_sub_pred_of(&lhs, &rhs),
) {
(true, true) => lhs, // lhs = rhs
(true, false) => lhs, // lhs :> rhs
(false, true) => rhs,
(false, false) => lhs | rhs,
}
}
pub(crate) fn intersection_pred(&self, lhs: Predicate, rhs: Predicate) -> Predicate {
match (
self.is_super_pred_of(&lhs, &rhs),
self.is_sub_pred_of(&lhs, &rhs),
) {
(true, true) => lhs,
(true, false) => rhs,
(false, true) => lhs,
(false, false) => lhs & rhs,
}
}
pub(crate) fn union_refinement(
&self,
lhs: &RefinementType,
@ -1509,8 +1541,8 @@ impl Context {
let union = self.union(&lhs.t, &rhs.t);
let name = lhs.var.clone();
let rhs_pred = rhs.pred.clone().change_subject_name(name);
// FIXME: predの包含関係も考慮する
RefinementType::new(lhs.var.clone(), union, *lhs.pred.clone() | rhs_pred)
let union_pred = self.union_pred(*lhs.pred.clone(), rhs_pred);
RefinementType::new(lhs.var.clone(), union, union_pred)
}
/// Returns intersection of two types (`A and B`).
@ -1614,6 +1646,7 @@ impl Context {
/// ```erg
/// {I: Int | I > 0} and {I: Int | I < 10} == {I: Int | I > 0 and I < 10}
/// {x: Int or NoneType | True} and {x: Obj | x != None} == {x: Int or NoneType | x != None} (== Int)
/// {x: Nat or None | x == 1 or x == None} and {x: Int | True} == {x: Int | x == 1}
/// ```
fn intersection_refinement(
&self,
@ -1623,7 +1656,13 @@ impl Context {
let intersec = self.intersection(&lhs.t, &rhs.t);
let name = lhs.var.clone();
let rhs_pred = rhs.pred.clone().change_subject_name(name);
RefinementType::new(lhs.var.clone(), intersec, *lhs.pred.clone() & rhs_pred)
let intersection_pred = self.intersection_pred(*lhs.pred.clone(), rhs_pred);
let Some(pred) =
self.eliminate_type_mismatched_preds(&lhs.var, &intersec, intersection_pred)
else {
return RefinementType::new(lhs.var.clone(), intersec, Predicate::TRUE);
};
RefinementType::new(lhs.var.clone(), intersec, pred)
}
/// ```erg
@ -1784,6 +1823,50 @@ impl Context {
reduced
}
fn eliminate_type_mismatched_preds(
&self,
var: &str,
t: &Type,
pred: Predicate,
) -> Option<Predicate> {
match pred {
Predicate::Equal { ref lhs, ref rhs }
| Predicate::NotEqual { ref lhs, ref rhs }
| Predicate::GreaterEqual { ref lhs, ref rhs }
| Predicate::LessEqual { ref lhs, ref rhs }
if lhs == var =>
{
let rhs_t = self.get_tp_t(rhs).unwrap_or(Obj);
if !self.subtype_of(&rhs_t, t) {
None
} else {
Some(pred)
}
}
Predicate::And(l, r) => {
let l = self.eliminate_type_mismatched_preds(var, t, *l);
let r = self.eliminate_type_mismatched_preds(var, t, *r);
match (l, r) {
(Some(l), Some(r)) => Some(l & r),
(Some(l), None) => Some(l),
(None, Some(r)) => Some(r),
(None, None) => None,
}
}
Predicate::Or(l, r) => {
let l = self.eliminate_type_mismatched_preds(var, t, *l);
let r = self.eliminate_type_mismatched_preds(var, t, *r);
match (l, r) {
(Some(l), Some(r)) => Some(l | r),
(Some(l), None) => Some(l),
(None, Some(r)) => Some(r),
(None, None) => None,
}
}
_ => Some(pred),
}
}
/// see doc/LANG/compiler/refinement_subtyping.md
/// ```python
/// assert is_super_pred({I >= 0}, {I == 0})

View file

@ -20,7 +20,7 @@ use erg_parser::desugar::Desugarer;
use erg_parser::token::{Token, TokenKind};
use crate::ty::constructors::{
array_t, bounded, closed_range, dict_t, func, mono, mono_q, named_free_var, poly, proj,
array_t, bounded, closed_range, dict_t, func, guard, mono, mono_q, named_free_var, poly, proj,
proj_call, ref_, ref_mut, refinement, set_t, subr_t, subtypeof, tp_enum, try_v_enum, tuple_t,
unknown_len_array_t, v_enum,
};
@ -171,8 +171,10 @@ impl<'c> Substituter<'c> {
let mut stps = st.typarams();
// Or, And are commutative, choose fitting order
if qt.qual_name() == st.qual_name() && (st.qual_name() == "Or" || st.qual_name() == "And") {
// REVIEW: correct condition?
if ctx.covariant_supertype_of_tp(&qtps[0], &stps[1])
&& ctx.covariant_supertype_of_tp(&qtps[1], &stps[0])
&& qt != st
{
stps.swap(0, 1);
}
@ -1843,6 +1845,10 @@ impl Context {
};
Ok(bounded(sub, sup))
}
Type::Guard(grd) => {
let to = self.eval_t_params(*grd.to, level, t_loc)?;
Ok(guard(grd.namespace, grd.target, to))
}
other if other.is_monomorphic() => Ok(other),
other => feature_error!(self, t_loc.loc(), &format!("eval {other}"))
.map_err(|errs| (other, errs)),

View file

@ -298,6 +298,10 @@ impl Generalizer {
res
}
}
Guard(grd) => {
let to = self.generalize_t(*grd.to, uninit);
guard(grd.namespace, grd.target, to)
}
// REVIEW: その他何でもそのまま通していいのか?
other => other,
}
@ -730,11 +734,11 @@ impl<'c, 'q, 'l, L: Locational> Dereferencer<'c, 'q, 'l, L> {
// ```
pub(crate) fn deref_tyvar(&mut self, t: Type) -> TyCheckResult<Type> {
match t {
Type::FreeVar(fv) if fv.is_linked() => {
FreeVar(fv) if fv.is_linked() => {
let t = fv.unwrap_linked();
self.deref_tyvar(t)
}
Type::FreeVar(mut fv)
FreeVar(mut fv)
if fv.is_generalized() && self.qnames.contains(&fv.unbound_name().unwrap()) =>
{
fv.update_init();
@ -745,7 +749,7 @@ impl<'c, 'q, 'l, L: Locational> Dereferencer<'c, 'q, 'l, L> {
// ?T(<: Int, :> Add(?T)) ==> Int
// ?T(:> Nat, <: Sub(Str)) ==> Error!
// ?T(:> {1, "a"}, <: Eq(?T(:> {1, "a"}, ...)) ==> Error!
Type::FreeVar(fv) if fv.constraint_is_sandwiched() => {
FreeVar(fv) if fv.constraint_is_sandwiched() => {
let (sub_t, super_t) = fv.get_subsup().unwrap();
if self.level <= fv.level().unwrap() {
// we need to force linking to avoid infinite loop
@ -792,7 +796,7 @@ impl<'c, 'q, 'l, L: Locational> Dereferencer<'c, 'q, 'l, L> {
Ok(Type::FreeVar(fv))
}
}
Type::FreeVar(fv) if fv.is_unbound() => {
FreeVar(fv) if fv.is_unbound() => {
if self.level == 0 {
match &*fv.crack_constraint() {
Constraint::TypeOf(t) if !t.is_type() => {
@ -812,7 +816,7 @@ impl<'c, 'q, 'l, L: Locational> Dereferencer<'c, 'q, 'l, L> {
Ok(Type::FreeVar(fv))
}
}
Type::Poly { name, mut params } => {
Poly { name, mut params } => {
let typ = poly(&name, params.clone());
let ctx = self.ctx.get_nominal_type_ctx(&typ).ok_or_else(|| {
TyCheckError::type_not_found(
@ -834,7 +838,7 @@ impl<'c, 'q, 'l, L: Locational> Dereferencer<'c, 'q, 'l, L> {
}
Ok(Type::Poly { name, params })
}
Type::Subr(mut subr) => {
Subr(mut subr) => {
for param in subr.non_default_params.iter_mut() {
self.push_variance(Contravariant);
*param.typ_mut() =
@ -874,7 +878,7 @@ impl<'c, 'q, 'l, L: Locational> Dereferencer<'c, 'q, 'l, L> {
self.pop_variance();
Ok(Type::Subr(subr))
}
Type::Callable {
Callable {
mut param_ts,
return_t,
} => {
@ -884,12 +888,12 @@ impl<'c, 'q, 'l, L: Locational> Dereferencer<'c, 'q, 'l, L> {
let return_t = self.deref_tyvar(*return_t)?;
Ok(callable(param_ts, return_t))
}
Type::Quantified(subr) => self.eliminate_needless_quant(*subr),
Type::Ref(t) => {
Quantified(subr) => self.eliminate_needless_quant(*subr),
Ref(t) => {
let t = self.deref_tyvar(*t)?;
Ok(ref_(t))
}
Type::RefMut { before, after } => {
RefMut { before, after } => {
let before = self.deref_tyvar(*before)?;
let after = if let Some(after) = after {
Some(self.deref_tyvar(*after)?)
@ -898,38 +902,38 @@ impl<'c, 'q, 'l, L: Locational> Dereferencer<'c, 'q, 'l, L> {
};
Ok(ref_mut(before, after))
}
Type::Record(mut rec) => {
Record(mut rec) => {
for (_, field) in rec.iter_mut() {
*field = self.deref_tyvar(mem::take(field))?;
}
Ok(Type::Record(rec))
}
Type::NamedTuple(mut rec) => {
NamedTuple(mut rec) => {
for (_, t) in rec.iter_mut() {
*t = self.deref_tyvar(mem::take(t))?;
}
Ok(Type::NamedTuple(rec))
}
Type::Refinement(refine) => {
Refinement(refine) => {
let t = self.deref_tyvar(*refine.t)?;
let pred = self.deref_pred(*refine.pred)?;
Ok(refinement(refine.var, t, pred))
}
Type::And(l, r) => {
And(l, r) => {
let l = self.deref_tyvar(*l)?;
let r = self.deref_tyvar(*r)?;
Ok(self.ctx.intersection(&l, &r))
}
Type::Or(l, r) => {
Or(l, r) => {
let l = self.deref_tyvar(*l)?;
let r = self.deref_tyvar(*r)?;
Ok(self.ctx.union(&l, &r))
}
Type::Not(ty) => {
Not(ty) => {
let ty = self.deref_tyvar(*ty)?;
Ok(self.ctx.complement(&ty))
}
Type::Proj { lhs, rhs } => {
Proj { lhs, rhs } => {
let proj = self
.ctx
.eval_proj(*lhs.clone(), rhs.clone(), self.level, self.loc)
@ -940,7 +944,7 @@ impl<'c, 'q, 'l, L: Locational> Dereferencer<'c, 'q, 'l, L> {
.unwrap_or(Failure);
Ok(proj)
}
Type::ProjCall {
ProjCall {
lhs,
attr_name,
args,
@ -956,10 +960,14 @@ impl<'c, 'q, 'l, L: Locational> Dereferencer<'c, 'q, 'l, L> {
.unwrap_or(Failure);
Ok(proj)
}
Type::Structural(inner) => {
Structural(inner) => {
let inner = self.deref_tyvar(*inner)?;
Ok(inner.structuralize())
}
Guard(grd) => {
let to = self.deref_tyvar(*grd.to)?;
Ok(guard(grd.namespace, grd.target, to))
}
t => Ok(t),
}
}
@ -975,11 +983,11 @@ impl<'c, 'q, 'l, L: Locational> Dereferencer<'c, 'q, 'l, L> {
}*/
// See tests\should_err\subtyping.er:8~13
(
Type::Poly {
Poly {
name: ln,
params: lps,
},
Type::Poly {
Poly {
name: rn,
params: rps,
},
@ -1565,7 +1573,7 @@ impl Context {
/// ```
pub(crate) fn squash_tyvar(&self, typ: Type) -> Type {
match typ {
Type::Or(l, r) => {
Or(l, r) => {
let l = self.squash_tyvar(*l);
let r = self.squash_tyvar(*r);
// REVIEW:
@ -1582,7 +1590,7 @@ impl Context {
}
self.union(&l, &r)
}
Type::FreeVar(ref fv) if fv.constraint_is_sandwiched() => {
FreeVar(ref fv) if fv.constraint_is_sandwiched() => {
let (sub_t, super_t) = fv.get_subsup().unwrap();
let sub_t = self.squash_tyvar(sub_t);
let super_t = self.squash_tyvar(super_t);

View file

@ -5,7 +5,7 @@ use erg_common::log;
use crate::ty::constructors::*;
use crate::ty::typaram::TyParam;
use crate::ty::value::ValueObj;
use crate::ty::{Field, Type, Visibility};
use crate::ty::{CastTarget, Field, GuardType, Type, Visibility};
use Type::*;
use crate::context::initialize::*;
@ -114,15 +114,29 @@ impl Context {
poly(ENUMERATE, vec![ty_tp(T.clone())]),
)
.quantify();
let guard = Type::Guard(GuardType::new(
"<builtins>".into(),
CastTarget::arg(0, "x".into(), Location::Unknown),
U.clone(),
));
let t_filter = nd_func(
vec![
kw(KW_FUNC, nd_func(vec![anon(T.clone())], None, Bool)),
kw(KW_FUNC, nd_func(vec![kw("x", T.clone())], None, guard)),
kw(KW_ITERABLE, poly(ITERABLE, vec![ty_tp(T.clone())])),
],
None,
poly(FILTER, vec![ty_tp(T.clone())]),
poly(FILTER, vec![ty_tp(T.clone() & U.clone())]),
)
.quantify();
.quantify()
& nd_func(
vec![
kw(KW_FUNC, nd_func(vec![anon(T.clone())], None, Bool)),
kw(KW_ITERABLE, poly(ITERABLE, vec![ty_tp(T.clone())])),
],
None,
poly(FILTER, vec![ty_tp(T.clone())]),
)
.quantify();
let filter = ValueObj::Subr(ConstSubr::Builtin(BuiltinConstSubr::new(
FUNC_FILTER,
filter_func,

View file

@ -16,6 +16,7 @@ use crate::ty::constructors::*;
use crate::ty::free::{Constraint, FreeTyParam, FreeTyVar, HasLevel, GENERIC_LEVEL};
use crate::ty::typaram::{TyParam, TyParamLambda};
use crate::ty::ConstSubr;
use crate::ty::GuardType;
use crate::ty::ValueObj;
use crate::ty::{HasType, Predicate, Type};
use crate::{type_feature_error, unreachable_error};
@ -919,6 +920,14 @@ impl Context {
let ty = self.instantiate_t_inner(*ty, tmp_tv_cache, loc)?;
Ok(self.complement(&ty))
}
Guard(guard) => {
let to = self.instantiate_t_inner(*guard.to, tmp_tv_cache, loc)?;
Ok(Type::Guard(GuardType::new(
guard.namespace,
guard.target,
to,
)))
}
other if other.is_monomorphic() => Ok(other),
other => type_feature_error!(self, loc.loc(), &format!("instantiating type {other}")),
}

View file

@ -1377,6 +1377,9 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
(Structural(sub), Structural(sup)) => {
self.sub_unify(sub, sup)?;
}
(Guard(sub), Guard(sup)) => {
self.sub_unify(&sub.to, &sup.to)?;
}
(sub, Structural(sup)) => {
let sub_fields = self.ctx.fields(sub);
for (sup_field, sup_ty) in self.ctx.fields(sup) {