feat: add Predicate::{Call, GeneralEqual}

This commit is contained in:
Shunsuke Shibayama 2024-01-28 15:30:02 +09:00
parent 2939c740a7
commit 4393649ffc
15 changed files with 746 additions and 22 deletions

View file

@ -614,6 +614,12 @@ impl Context {
return false;
}
}
for tp in r.pred.possible_tps() {
let substituted = l.pred.clone().substitute(&l.var, tp);
if self.bool_eval_pred(substituted).is_ok_and(|b| b) {
return true;
}
}
self.is_super_pred_of(&l.pred, &r.pred)
}
(Nat | Bool, re @ Refinement(_)) => {
@ -633,6 +639,7 @@ impl Context {
// Array({1, 2}, _) :> {[3, 4]} == false
(l, Refinement(r)) => {
// Type / {S: Set(Str) | S == {"a", "b"}}
// TODO: GeneralEq
if let Pred::Equal { rhs, .. } = r.pred.as_ref() {
if self.subtype_of(l, &Type) && self.convert_tp_into_type(rhs.clone()).is_ok() {
return true;
@ -1514,6 +1521,21 @@ impl Context {
| Predicate::LessEqual { rhs, .. } => self.get_tp_t(rhs).unwrap_or(Obj),
Predicate::Not(pred) => self.get_pred_type(pred),
Predicate::Value(val) => val.class(),
Predicate::Call { receiver, name, .. } => {
let receiver_t = self.get_tp_t(receiver).unwrap_or(Obj);
if let Some(name) = name {
let ctx = self.get_nominal_type_ctx(&receiver_t).unwrap();
if let Some((_, method)) = ctx.get_var_info(name) {
method.t.return_t().cloned().unwrap_or(Obj)
} else {
Obj
}
} else {
receiver_t.return_t().cloned().unwrap_or(Obj)
}
}
// REVIEW
Predicate::GeneralEqual { rhs, .. } => self.get_pred_type(rhs),
// x == 1 or x == "a" => Int or Str
Predicate::Or(lhs, rhs) => {
self.union(&self.get_pred_type(lhs), &self.get_pred_type(rhs))
@ -1645,6 +1667,13 @@ impl Context {
.map(|ord| ord.canbe_eq())
.unwrap_or(false)
}
(
Pred::GeneralEqual { lhs, rhs },
Pred::GeneralEqual {
lhs: lhs2,
rhs: rhs2,
},
) => self.is_super_pred_of(lhs, lhs2) && self.is_super_pred_of(rhs, rhs2),
// {T >= 0} :> {T >= 1}, {T >= 0} :> {T == 1}
(
Pred::GreaterEqual { rhs, .. },

View file

@ -2572,9 +2572,43 @@ impl Context {
}
}
pub(crate) fn bool_eval_pred(&self, p: Predicate) -> EvalResult<bool> {
let evaled = self.eval_pred(p)?;
Ok(matches!(evaled, Predicate::Value(ValueObj::Bool(true))))
}
pub(crate) fn eval_pred(&self, p: Predicate) -> EvalResult<Predicate> {
match p {
Predicate::Value(_) | Predicate::Const(_) => Ok(p),
Predicate::Call {
receiver,
name,
args,
} => {
let receiver = self.eval_tp(receiver)?;
let mut new_args = vec![];
for arg in args {
new_args.push(self.eval_tp(arg)?);
}
let t = if let Some(name) = name {
self.eval_proj_call(receiver, name, new_args, &())?
} else {
return feature_error!(self, Location::Unknown, "eval_pred: Predicate::Call");
};
if let TyParam::Value(v) = t {
Ok(Predicate::Value(v))
} else {
feature_error!(self, Location::Unknown, "eval_pred: Predicate::Call")
}
}
Predicate::GeneralEqual { lhs, rhs } => {
match (self.eval_pred(*lhs)?, self.eval_pred(*rhs)?) {
(Predicate::Value(lhs), Predicate::Value(rhs)) => {
Ok(Predicate::Value(ValueObj::Bool(lhs == rhs)))
}
(lhs, rhs) => Ok(Predicate::general_eq(lhs, rhs)),
}
}
Predicate::Equal { lhs, rhs } => Ok(Predicate::eq(lhs, self.eval_tp(rhs)?)),
Predicate::NotEqual { lhs, rhs } => Ok(Predicate::ne(lhs, self.eval_tp(rhs)?)),
Predicate::LessEqual { lhs, rhs } => Ok(Predicate::le(lhs, self.eval_tp(rhs)?)),

View file

@ -323,7 +323,24 @@ impl Generalizer {
*typ.typ_mut() = self.generalize_t(mem::take(typ.typ_mut()), uninit);
Predicate::Value(ValueObj::Type(typ))
}
Predicate::Call {
receiver,
name,
args,
} => {
let receiver = self.generalize_tp(receiver, uninit);
let mut new_args = vec![];
for arg in args.into_iter() {
new_args.push(self.generalize_tp(arg, uninit));
}
Predicate::call(receiver, name, new_args)
}
Predicate::Value(_) => pred,
Predicate::GeneralEqual { lhs, rhs } => {
let lhs = self.generalize_pred(*lhs, uninit);
let rhs = self.generalize_pred(*rhs, uninit);
Predicate::general_eq(lhs, rhs)
}
Predicate::Equal { lhs, rhs } => {
let rhs = self.generalize_tp(rhs, uninit);
Predicate::eq(lhs, rhs)
@ -414,6 +431,11 @@ impl<'c, 'q, 'l, L: Locational> Dereferencer<'c, 'q, 'l, L> {
TyParam::FreeVar(_) if self.ctx.level == 0 => {
Ok(TyParam::erased(self.ctx.get_tp_t(&tp).unwrap_or(Type::Obj)))
}
TyParam::FreeVar(fv) if fv.get_type().is_some() => {
let t = self.deref_tyvar(fv.get_type().unwrap())?;
fv.update_type(t);
Ok(TyParam::FreeVar(fv))
}
TyParam::Type(t) => Ok(TyParam::t(self.deref_tyvar(*t)?)),
TyParam::Erased(t) => Ok(TyParam::erased(self.deref_tyvar(*t)?)),
TyParam::App { name, mut args } => {
@ -539,6 +561,78 @@ impl<'c, 'q, 'l, L: Locational> Dereferencer<'c, 'q, 'l, L> {
}
}
fn deref_pred(&mut self, pred: Predicate) -> TyCheckResult<Predicate> {
match pred {
Predicate::Equal { lhs, rhs } => {
let rhs = self.deref_tp(rhs)?;
Ok(Predicate::eq(lhs, rhs))
}
Predicate::GreaterEqual { lhs, rhs } => {
let rhs = self.deref_tp(rhs)?;
Ok(Predicate::ge(lhs, rhs))
}
Predicate::LessEqual { lhs, rhs } => {
let rhs = self.deref_tp(rhs)?;
Ok(Predicate::le(lhs, rhs))
}
Predicate::NotEqual { lhs, rhs } => {
let rhs = self.deref_tp(rhs)?;
Ok(Predicate::ne(lhs, rhs))
}
Predicate::GeneralEqual { lhs, rhs } => {
let lhs = self.deref_pred(*lhs)?;
let rhs = self.deref_pred(*rhs)?;
match (lhs, rhs) {
(Predicate::Value(lhs), Predicate::Value(rhs)) => {
Ok(Predicate::Value(ValueObj::Bool(lhs == rhs)))
}
(lhs, rhs) => Ok(Predicate::general_eq(lhs, rhs)),
}
}
Predicate::Call {
receiver,
name,
args,
} => {
let Ok(receiver) = self.deref_tp(receiver.clone()) else {
return Ok(Predicate::call(receiver, name, args));
};
let mut new_args = vec![];
for arg in args.into_iter() {
let Ok(arg) = self.deref_tp(arg) else {
return Ok(Predicate::call(receiver, name, new_args));
};
new_args.push(arg);
}
let evaled = if let Some(name) = &name {
self.ctx
.eval_proj_call(receiver.clone(), name.clone(), new_args.clone(), &())
} else {
return Ok(Predicate::call(receiver, name, new_args));
};
match evaled {
Ok(TyParam::Value(value)) => Ok(Predicate::Value(value)),
_ => Ok(Predicate::call(receiver, name, new_args)),
}
}
Predicate::And(lhs, rhs) => {
let lhs = self.deref_pred(*lhs)?;
let rhs = self.deref_pred(*rhs)?;
Ok(Predicate::and(lhs, rhs))
}
Predicate::Or(lhs, rhs) => {
let lhs = self.deref_pred(*lhs)?;
let rhs = self.deref_pred(*rhs)?;
Ok(Predicate::or(lhs, rhs))
}
Predicate::Not(pred) => {
let pred = self.deref_pred(*pred)?;
Ok(!pred)
}
_ => Ok(pred),
}
}
fn deref_constraint(&mut self, constraint: Constraint) -> TyCheckResult<Constraint> {
match constraint {
Constraint::Sandwiched { sub, sup } => Ok(Constraint::new_sandwiched(
@ -741,9 +835,10 @@ impl<'c, 'q, 'l, L: Locational> Dereferencer<'c, 'q, 'l, L> {
Ok(Type::NamedTuple(rec))
}
Type::Refinement(refine) => {
log!(err "deref_tyvar: {} / {}", refine.t, refine.pred);
let t = self.deref_tyvar(*refine.t)?;
// TODO: deref_predicate
Ok(refinement(refine.var, t, *refine.pred))
let pred = self.deref_pred(*refine.pred)?;
Ok(refinement(refine.var, t, pred))
}
Type::And(l, r) => {
let l = self.deref_tyvar(*l)?;
@ -797,6 +892,12 @@ impl<'c, 'q, 'l, L: Locational> Dereferencer<'c, 'q, 'l, L> {
fn validate_subsup(&mut self, sub_t: Type, super_t: Type) -> TyCheckResult<Type> {
// TODO: Subr, ...
match (sub_t, super_t) {
/*(sub_t @ Type::Refinement(_), super_t @ Type::Refinement(_)) => {
self.validate_simple_subsup(sub_t, super_t)
}
(Type::Refinement(refine), super_t) => {
self.validate_simple_subsup(*refine.t, super_t)
}*/
// See tests\should_err\subtyping.er:8~13
(
Type::Poly {

View file

@ -1611,6 +1611,32 @@ impl Context {
array_t(T.clone(), TyParam::erased(Nat)),
);
array_.register_py_builtin(FUNC_DEDUP, t.quantify(), Some(FUNC_DEDUP), 28);
let sum_t = no_var_fn_met(
array_t(T.clone(), TyParam::erased(Nat)),
vec![],
vec![kw("start", T.clone())],
T.clone(),
);
let sum = ValueObj::Subr(ConstSubr::Builtin(BuiltinConstSubr::new(
FUNC_SUM,
array_sum,
sum_t.quantify(),
None,
)));
array_.register_builtin_const(FUNC_SUM, Visibility::BUILTIN_PUBLIC, sum);
let prod_t = no_var_fn_met(
array_t(T.clone(), TyParam::erased(Nat)),
vec![],
vec![kw("start", T.clone())],
T.clone(),
);
let prod = ValueObj::Subr(ConstSubr::Builtin(BuiltinConstSubr::new(
FUNC_PROD,
array_prod,
prod_t.quantify(),
None,
)));
array_.register_builtin_const(FUNC_PROD, Visibility::BUILTIN_PUBLIC, prod);
/* Slice */
let mut slice = Self::builtin_mono_class(SLICE, 3);
slice.register_superclass(Obj, &obj);

View file

@ -519,6 +519,106 @@ pub(crate) fn array_shape(mut args: ValueArgs, ctx: &Context) -> EvalValueResult
Ok(arr)
}
fn _array_sum(arr: ValueObj, _ctx: &Context) -> Result<ValueObj, String> {
match arr {
ValueObj::Array(a) => {
let mut sum = 0f64;
for v in a.iter() {
match v {
ValueObj::Nat(n) => {
sum += *n as f64;
}
ValueObj::Int(n) => {
sum += *n as f64;
}
ValueObj::Float(n) => {
sum += *n;
}
ValueObj::Inf => {
return Ok(ValueObj::Inf);
}
ValueObj::NegInf => {
return Ok(ValueObj::NegInf);
}
_ => {
return Err(format!("Cannot sum {v}"));
}
}
}
if sum.round() == sum && sum >= 0.0 {
Ok(ValueObj::Nat(sum as u64))
} else if sum.round() == sum {
Ok(ValueObj::Int(sum as i32))
} else {
Ok(ValueObj::Float(sum))
}
}
_ => Err(format!("Cannot sum {arr}")),
}
}
/// ```erg
/// [1, 2].sum() == [3,]
/// ```
pub(crate) fn array_sum(mut args: ValueArgs, ctx: &Context) -> EvalValueResult<TyParam> {
let arr = args
.remove_left_or_key("Self")
.ok_or_else(|| not_passed("Self"))?;
let res = _array_sum(arr, ctx).unwrap();
let arr = TyParam::Value(res);
Ok(arr)
}
fn _array_prod(arr: ValueObj, _ctx: &Context) -> Result<ValueObj, String> {
match arr {
ValueObj::Array(a) => {
let mut prod = 1f64;
for v in a.iter() {
match v {
ValueObj::Nat(n) => {
prod *= *n as f64;
}
ValueObj::Int(n) => {
prod *= *n as f64;
}
ValueObj::Float(n) => {
prod *= *n;
}
ValueObj::Inf => {
return Ok(ValueObj::Inf);
}
ValueObj::NegInf => {
return Ok(ValueObj::NegInf);
}
_ => {
return Err(format!("Cannot prod {v}"));
}
}
}
if prod.round() == prod && prod >= 0.0 {
Ok(ValueObj::Nat(prod as u64))
} else if prod.round() == prod {
Ok(ValueObj::Int(prod as i32))
} else {
Ok(ValueObj::Float(prod))
}
}
_ => Err(format!("Cannot prod {arr}")),
}
}
/// ```erg
/// [1, 2].prod() == [2,]
/// ```
pub(crate) fn array_prod(mut args: ValueArgs, ctx: &Context) -> EvalValueResult<TyParam> {
let arr = args
.remove_left_or_key("Self")
.ok_or_else(|| not_passed("Self"))?;
let res = _array_prod(arr, ctx).unwrap();
let arr = TyParam::Value(res);
Ok(arr)
}
pub(crate) fn __range_getitem__(mut args: ValueArgs, _ctx: &Context) -> EvalValueResult<TyParam> {
let slf = args
.remove_left_or_key("Self")

View file

@ -419,6 +419,7 @@ const FUNC_DELATTR: &str = "delattr";
const FUNC_NEARLY_EQ: &str = "nearly_eq";
const FUNC_RESOLVE_PATH: &str = "ResolvePath";
const FUNC_RESOLVE_DECL_PATH: &str = "ResolveDeclPath";
const FUNC_PROD: &str = "prod";
const OP_EQ: &str = "__eq__";
const OP_HASH: &str = "__hash__";

View file

@ -137,10 +137,6 @@ impl TyVarCache {
}
}
fn _instantiate_pred(&self, _pred: Predicate) -> Predicate {
todo!()
}
/// Some of the quantified types are circulating.
/// e.g.
/// ```erg
@ -576,6 +572,23 @@ impl Context {
let value = self.instantiate_value(value, tmp_tv_cache, loc)?;
Ok(Predicate::Value(value))
}
Predicate::Call {
receiver,
name,
args,
} => {
let receiver = self.instantiate_tp(receiver, tmp_tv_cache, loc)?;
let mut new_args = Vec::with_capacity(args.len());
for arg in args {
new_args.push(self.instantiate_tp(arg, tmp_tv_cache, loc)?);
}
Ok(Predicate::call(receiver, name, new_args))
}
Predicate::GeneralEqual { lhs, rhs } => {
let lhs = self.instantiate_pred(*lhs, tmp_tv_cache, loc)?;
let rhs = self.instantiate_pred(*rhs, tmp_tv_cache, loc)?;
Ok(Predicate::general_eq(lhs, rhs))
}
_ => Ok(pred),
}
}

View file

@ -1462,6 +1462,20 @@ impl Context {
self.inc_ref_local(local, self, tmp_tv_cache);
Ok(Predicate::Const(local.inspect().clone()))
}
ast::ConstExpr::App(app) => {
let receiver = self.instantiate_const_expr(&app.obj, None, tmp_tv_cache, false)?;
let name = app.attr_name.as_ref().map(|n| n.inspect().to_owned());
let mut args = vec![];
for arg in app.args.pos_args() {
let arg = self.instantiate_const_expr(&arg.expr, None, tmp_tv_cache, false)?;
args.push(arg);
}
Ok(Predicate::Call {
receiver,
name,
args,
})
}
ast::ConstExpr::BinOp(bin) => {
let lhs = self.instantiate_pred_from_expr(&bin.lhs, tmp_tv_cache)?;
let rhs = self.instantiate_pred_from_expr(&bin.rhs, tmp_tv_cache)?;
@ -1472,16 +1486,25 @@ impl Context {
| TokenKind::LessEq
| TokenKind::Gre
| TokenKind::GreEq => {
let Predicate::Const(var) = lhs else {
return type_feature_error!(
self,
bin.loc(),
&format!("instantiating predicate `{expr}`")
);
let var = match lhs {
Predicate::Const(var) => var,
other if bin.op.kind == TokenKind::DblEq => {
return Ok(Predicate::general_eq(other, rhs));
}
_ => {
return type_feature_error!(
self,
bin.loc(),
&format!("instantiating predicate `{expr}`")
);
}
};
let rhs = match rhs {
Predicate::Value(value) => TyParam::Value(value),
Predicate::Const(var) => TyParam::Mono(var),
other if bin.op.kind == TokenKind::DblEq => {
return Ok(Predicate::general_eq(Predicate::Const(var), other));
}
_ => {
return type_feature_error!(
self,