mirror of
https://github.com/erg-lang/erg.git
synced 2025-08-04 02:39:20 +00:00
feat: add Predicate::{Call, GeneralEqual}
This commit is contained in:
parent
2939c740a7
commit
4393649ffc
15 changed files with 746 additions and 22 deletions
|
@ -614,6 +614,12 @@ impl Context {
|
|||
return false;
|
||||
}
|
||||
}
|
||||
for tp in r.pred.possible_tps() {
|
||||
let substituted = l.pred.clone().substitute(&l.var, tp);
|
||||
if self.bool_eval_pred(substituted).is_ok_and(|b| b) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
self.is_super_pred_of(&l.pred, &r.pred)
|
||||
}
|
||||
(Nat | Bool, re @ Refinement(_)) => {
|
||||
|
@ -633,6 +639,7 @@ impl Context {
|
|||
// Array({1, 2}, _) :> {[3, 4]} == false
|
||||
(l, Refinement(r)) => {
|
||||
// Type / {S: Set(Str) | S == {"a", "b"}}
|
||||
// TODO: GeneralEq
|
||||
if let Pred::Equal { rhs, .. } = r.pred.as_ref() {
|
||||
if self.subtype_of(l, &Type) && self.convert_tp_into_type(rhs.clone()).is_ok() {
|
||||
return true;
|
||||
|
@ -1514,6 +1521,21 @@ impl Context {
|
|||
| Predicate::LessEqual { rhs, .. } => self.get_tp_t(rhs).unwrap_or(Obj),
|
||||
Predicate::Not(pred) => self.get_pred_type(pred),
|
||||
Predicate::Value(val) => val.class(),
|
||||
Predicate::Call { receiver, name, .. } => {
|
||||
let receiver_t = self.get_tp_t(receiver).unwrap_or(Obj);
|
||||
if let Some(name) = name {
|
||||
let ctx = self.get_nominal_type_ctx(&receiver_t).unwrap();
|
||||
if let Some((_, method)) = ctx.get_var_info(name) {
|
||||
method.t.return_t().cloned().unwrap_or(Obj)
|
||||
} else {
|
||||
Obj
|
||||
}
|
||||
} else {
|
||||
receiver_t.return_t().cloned().unwrap_or(Obj)
|
||||
}
|
||||
}
|
||||
// REVIEW
|
||||
Predicate::GeneralEqual { rhs, .. } => self.get_pred_type(rhs),
|
||||
// x == 1 or x == "a" => Int or Str
|
||||
Predicate::Or(lhs, rhs) => {
|
||||
self.union(&self.get_pred_type(lhs), &self.get_pred_type(rhs))
|
||||
|
@ -1645,6 +1667,13 @@ impl Context {
|
|||
.map(|ord| ord.canbe_eq())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
(
|
||||
Pred::GeneralEqual { lhs, rhs },
|
||||
Pred::GeneralEqual {
|
||||
lhs: lhs2,
|
||||
rhs: rhs2,
|
||||
},
|
||||
) => self.is_super_pred_of(lhs, lhs2) && self.is_super_pred_of(rhs, rhs2),
|
||||
// {T >= 0} :> {T >= 1}, {T >= 0} :> {T == 1}
|
||||
(
|
||||
Pred::GreaterEqual { rhs, .. },
|
||||
|
|
|
@ -2572,9 +2572,43 @@ impl Context {
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) fn bool_eval_pred(&self, p: Predicate) -> EvalResult<bool> {
|
||||
let evaled = self.eval_pred(p)?;
|
||||
Ok(matches!(evaled, Predicate::Value(ValueObj::Bool(true))))
|
||||
}
|
||||
|
||||
pub(crate) fn eval_pred(&self, p: Predicate) -> EvalResult<Predicate> {
|
||||
match p {
|
||||
Predicate::Value(_) | Predicate::Const(_) => Ok(p),
|
||||
Predicate::Call {
|
||||
receiver,
|
||||
name,
|
||||
args,
|
||||
} => {
|
||||
let receiver = self.eval_tp(receiver)?;
|
||||
let mut new_args = vec![];
|
||||
for arg in args {
|
||||
new_args.push(self.eval_tp(arg)?);
|
||||
}
|
||||
let t = if let Some(name) = name {
|
||||
self.eval_proj_call(receiver, name, new_args, &())?
|
||||
} else {
|
||||
return feature_error!(self, Location::Unknown, "eval_pred: Predicate::Call");
|
||||
};
|
||||
if let TyParam::Value(v) = t {
|
||||
Ok(Predicate::Value(v))
|
||||
} else {
|
||||
feature_error!(self, Location::Unknown, "eval_pred: Predicate::Call")
|
||||
}
|
||||
}
|
||||
Predicate::GeneralEqual { lhs, rhs } => {
|
||||
match (self.eval_pred(*lhs)?, self.eval_pred(*rhs)?) {
|
||||
(Predicate::Value(lhs), Predicate::Value(rhs)) => {
|
||||
Ok(Predicate::Value(ValueObj::Bool(lhs == rhs)))
|
||||
}
|
||||
(lhs, rhs) => Ok(Predicate::general_eq(lhs, rhs)),
|
||||
}
|
||||
}
|
||||
Predicate::Equal { lhs, rhs } => Ok(Predicate::eq(lhs, self.eval_tp(rhs)?)),
|
||||
Predicate::NotEqual { lhs, rhs } => Ok(Predicate::ne(lhs, self.eval_tp(rhs)?)),
|
||||
Predicate::LessEqual { lhs, rhs } => Ok(Predicate::le(lhs, self.eval_tp(rhs)?)),
|
||||
|
|
|
@ -323,7 +323,24 @@ impl Generalizer {
|
|||
*typ.typ_mut() = self.generalize_t(mem::take(typ.typ_mut()), uninit);
|
||||
Predicate::Value(ValueObj::Type(typ))
|
||||
}
|
||||
Predicate::Call {
|
||||
receiver,
|
||||
name,
|
||||
args,
|
||||
} => {
|
||||
let receiver = self.generalize_tp(receiver, uninit);
|
||||
let mut new_args = vec![];
|
||||
for arg in args.into_iter() {
|
||||
new_args.push(self.generalize_tp(arg, uninit));
|
||||
}
|
||||
Predicate::call(receiver, name, new_args)
|
||||
}
|
||||
Predicate::Value(_) => pred,
|
||||
Predicate::GeneralEqual { lhs, rhs } => {
|
||||
let lhs = self.generalize_pred(*lhs, uninit);
|
||||
let rhs = self.generalize_pred(*rhs, uninit);
|
||||
Predicate::general_eq(lhs, rhs)
|
||||
}
|
||||
Predicate::Equal { lhs, rhs } => {
|
||||
let rhs = self.generalize_tp(rhs, uninit);
|
||||
Predicate::eq(lhs, rhs)
|
||||
|
@ -414,6 +431,11 @@ impl<'c, 'q, 'l, L: Locational> Dereferencer<'c, 'q, 'l, L> {
|
|||
TyParam::FreeVar(_) if self.ctx.level == 0 => {
|
||||
Ok(TyParam::erased(self.ctx.get_tp_t(&tp).unwrap_or(Type::Obj)))
|
||||
}
|
||||
TyParam::FreeVar(fv) if fv.get_type().is_some() => {
|
||||
let t = self.deref_tyvar(fv.get_type().unwrap())?;
|
||||
fv.update_type(t);
|
||||
Ok(TyParam::FreeVar(fv))
|
||||
}
|
||||
TyParam::Type(t) => Ok(TyParam::t(self.deref_tyvar(*t)?)),
|
||||
TyParam::Erased(t) => Ok(TyParam::erased(self.deref_tyvar(*t)?)),
|
||||
TyParam::App { name, mut args } => {
|
||||
|
@ -539,6 +561,78 @@ impl<'c, 'q, 'l, L: Locational> Dereferencer<'c, 'q, 'l, L> {
|
|||
}
|
||||
}
|
||||
|
||||
fn deref_pred(&mut self, pred: Predicate) -> TyCheckResult<Predicate> {
|
||||
match pred {
|
||||
Predicate::Equal { lhs, rhs } => {
|
||||
let rhs = self.deref_tp(rhs)?;
|
||||
Ok(Predicate::eq(lhs, rhs))
|
||||
}
|
||||
Predicate::GreaterEqual { lhs, rhs } => {
|
||||
let rhs = self.deref_tp(rhs)?;
|
||||
Ok(Predicate::ge(lhs, rhs))
|
||||
}
|
||||
Predicate::LessEqual { lhs, rhs } => {
|
||||
let rhs = self.deref_tp(rhs)?;
|
||||
Ok(Predicate::le(lhs, rhs))
|
||||
}
|
||||
Predicate::NotEqual { lhs, rhs } => {
|
||||
let rhs = self.deref_tp(rhs)?;
|
||||
Ok(Predicate::ne(lhs, rhs))
|
||||
}
|
||||
Predicate::GeneralEqual { lhs, rhs } => {
|
||||
let lhs = self.deref_pred(*lhs)?;
|
||||
let rhs = self.deref_pred(*rhs)?;
|
||||
match (lhs, rhs) {
|
||||
(Predicate::Value(lhs), Predicate::Value(rhs)) => {
|
||||
Ok(Predicate::Value(ValueObj::Bool(lhs == rhs)))
|
||||
}
|
||||
(lhs, rhs) => Ok(Predicate::general_eq(lhs, rhs)),
|
||||
}
|
||||
}
|
||||
Predicate::Call {
|
||||
receiver,
|
||||
name,
|
||||
args,
|
||||
} => {
|
||||
let Ok(receiver) = self.deref_tp(receiver.clone()) else {
|
||||
return Ok(Predicate::call(receiver, name, args));
|
||||
};
|
||||
let mut new_args = vec![];
|
||||
for arg in args.into_iter() {
|
||||
let Ok(arg) = self.deref_tp(arg) else {
|
||||
return Ok(Predicate::call(receiver, name, new_args));
|
||||
};
|
||||
new_args.push(arg);
|
||||
}
|
||||
let evaled = if let Some(name) = &name {
|
||||
self.ctx
|
||||
.eval_proj_call(receiver.clone(), name.clone(), new_args.clone(), &())
|
||||
} else {
|
||||
return Ok(Predicate::call(receiver, name, new_args));
|
||||
};
|
||||
match evaled {
|
||||
Ok(TyParam::Value(value)) => Ok(Predicate::Value(value)),
|
||||
_ => Ok(Predicate::call(receiver, name, new_args)),
|
||||
}
|
||||
}
|
||||
Predicate::And(lhs, rhs) => {
|
||||
let lhs = self.deref_pred(*lhs)?;
|
||||
let rhs = self.deref_pred(*rhs)?;
|
||||
Ok(Predicate::and(lhs, rhs))
|
||||
}
|
||||
Predicate::Or(lhs, rhs) => {
|
||||
let lhs = self.deref_pred(*lhs)?;
|
||||
let rhs = self.deref_pred(*rhs)?;
|
||||
Ok(Predicate::or(lhs, rhs))
|
||||
}
|
||||
Predicate::Not(pred) => {
|
||||
let pred = self.deref_pred(*pred)?;
|
||||
Ok(!pred)
|
||||
}
|
||||
_ => Ok(pred),
|
||||
}
|
||||
}
|
||||
|
||||
fn deref_constraint(&mut self, constraint: Constraint) -> TyCheckResult<Constraint> {
|
||||
match constraint {
|
||||
Constraint::Sandwiched { sub, sup } => Ok(Constraint::new_sandwiched(
|
||||
|
@ -741,9 +835,10 @@ impl<'c, 'q, 'l, L: Locational> Dereferencer<'c, 'q, 'l, L> {
|
|||
Ok(Type::NamedTuple(rec))
|
||||
}
|
||||
Type::Refinement(refine) => {
|
||||
log!(err "deref_tyvar: {} / {}", refine.t, refine.pred);
|
||||
let t = self.deref_tyvar(*refine.t)?;
|
||||
// TODO: deref_predicate
|
||||
Ok(refinement(refine.var, t, *refine.pred))
|
||||
let pred = self.deref_pred(*refine.pred)?;
|
||||
Ok(refinement(refine.var, t, pred))
|
||||
}
|
||||
Type::And(l, r) => {
|
||||
let l = self.deref_tyvar(*l)?;
|
||||
|
@ -797,6 +892,12 @@ impl<'c, 'q, 'l, L: Locational> Dereferencer<'c, 'q, 'l, L> {
|
|||
fn validate_subsup(&mut self, sub_t: Type, super_t: Type) -> TyCheckResult<Type> {
|
||||
// TODO: Subr, ...
|
||||
match (sub_t, super_t) {
|
||||
/*(sub_t @ Type::Refinement(_), super_t @ Type::Refinement(_)) => {
|
||||
self.validate_simple_subsup(sub_t, super_t)
|
||||
}
|
||||
(Type::Refinement(refine), super_t) => {
|
||||
self.validate_simple_subsup(*refine.t, super_t)
|
||||
}*/
|
||||
// See tests\should_err\subtyping.er:8~13
|
||||
(
|
||||
Type::Poly {
|
||||
|
|
|
@ -1611,6 +1611,32 @@ impl Context {
|
|||
array_t(T.clone(), TyParam::erased(Nat)),
|
||||
);
|
||||
array_.register_py_builtin(FUNC_DEDUP, t.quantify(), Some(FUNC_DEDUP), 28);
|
||||
let sum_t = no_var_fn_met(
|
||||
array_t(T.clone(), TyParam::erased(Nat)),
|
||||
vec![],
|
||||
vec![kw("start", T.clone())],
|
||||
T.clone(),
|
||||
);
|
||||
let sum = ValueObj::Subr(ConstSubr::Builtin(BuiltinConstSubr::new(
|
||||
FUNC_SUM,
|
||||
array_sum,
|
||||
sum_t.quantify(),
|
||||
None,
|
||||
)));
|
||||
array_.register_builtin_const(FUNC_SUM, Visibility::BUILTIN_PUBLIC, sum);
|
||||
let prod_t = no_var_fn_met(
|
||||
array_t(T.clone(), TyParam::erased(Nat)),
|
||||
vec![],
|
||||
vec![kw("start", T.clone())],
|
||||
T.clone(),
|
||||
);
|
||||
let prod = ValueObj::Subr(ConstSubr::Builtin(BuiltinConstSubr::new(
|
||||
FUNC_PROD,
|
||||
array_prod,
|
||||
prod_t.quantify(),
|
||||
None,
|
||||
)));
|
||||
array_.register_builtin_const(FUNC_PROD, Visibility::BUILTIN_PUBLIC, prod);
|
||||
/* Slice */
|
||||
let mut slice = Self::builtin_mono_class(SLICE, 3);
|
||||
slice.register_superclass(Obj, &obj);
|
||||
|
|
|
@ -519,6 +519,106 @@ pub(crate) fn array_shape(mut args: ValueArgs, ctx: &Context) -> EvalValueResult
|
|||
Ok(arr)
|
||||
}
|
||||
|
||||
fn _array_sum(arr: ValueObj, _ctx: &Context) -> Result<ValueObj, String> {
|
||||
match arr {
|
||||
ValueObj::Array(a) => {
|
||||
let mut sum = 0f64;
|
||||
for v in a.iter() {
|
||||
match v {
|
||||
ValueObj::Nat(n) => {
|
||||
sum += *n as f64;
|
||||
}
|
||||
ValueObj::Int(n) => {
|
||||
sum += *n as f64;
|
||||
}
|
||||
ValueObj::Float(n) => {
|
||||
sum += *n;
|
||||
}
|
||||
ValueObj::Inf => {
|
||||
return Ok(ValueObj::Inf);
|
||||
}
|
||||
ValueObj::NegInf => {
|
||||
return Ok(ValueObj::NegInf);
|
||||
}
|
||||
_ => {
|
||||
return Err(format!("Cannot sum {v}"));
|
||||
}
|
||||
}
|
||||
}
|
||||
if sum.round() == sum && sum >= 0.0 {
|
||||
Ok(ValueObj::Nat(sum as u64))
|
||||
} else if sum.round() == sum {
|
||||
Ok(ValueObj::Int(sum as i32))
|
||||
} else {
|
||||
Ok(ValueObj::Float(sum))
|
||||
}
|
||||
}
|
||||
_ => Err(format!("Cannot sum {arr}")),
|
||||
}
|
||||
}
|
||||
|
||||
/// ```erg
|
||||
/// [1, 2].sum() == [3,]
|
||||
/// ```
|
||||
pub(crate) fn array_sum(mut args: ValueArgs, ctx: &Context) -> EvalValueResult<TyParam> {
|
||||
let arr = args
|
||||
.remove_left_or_key("Self")
|
||||
.ok_or_else(|| not_passed("Self"))?;
|
||||
let res = _array_sum(arr, ctx).unwrap();
|
||||
let arr = TyParam::Value(res);
|
||||
Ok(arr)
|
||||
}
|
||||
|
||||
fn _array_prod(arr: ValueObj, _ctx: &Context) -> Result<ValueObj, String> {
|
||||
match arr {
|
||||
ValueObj::Array(a) => {
|
||||
let mut prod = 1f64;
|
||||
for v in a.iter() {
|
||||
match v {
|
||||
ValueObj::Nat(n) => {
|
||||
prod *= *n as f64;
|
||||
}
|
||||
ValueObj::Int(n) => {
|
||||
prod *= *n as f64;
|
||||
}
|
||||
ValueObj::Float(n) => {
|
||||
prod *= *n;
|
||||
}
|
||||
ValueObj::Inf => {
|
||||
return Ok(ValueObj::Inf);
|
||||
}
|
||||
ValueObj::NegInf => {
|
||||
return Ok(ValueObj::NegInf);
|
||||
}
|
||||
_ => {
|
||||
return Err(format!("Cannot prod {v}"));
|
||||
}
|
||||
}
|
||||
}
|
||||
if prod.round() == prod && prod >= 0.0 {
|
||||
Ok(ValueObj::Nat(prod as u64))
|
||||
} else if prod.round() == prod {
|
||||
Ok(ValueObj::Int(prod as i32))
|
||||
} else {
|
||||
Ok(ValueObj::Float(prod))
|
||||
}
|
||||
}
|
||||
_ => Err(format!("Cannot prod {arr}")),
|
||||
}
|
||||
}
|
||||
|
||||
/// ```erg
|
||||
/// [1, 2].prod() == [2,]
|
||||
/// ```
|
||||
pub(crate) fn array_prod(mut args: ValueArgs, ctx: &Context) -> EvalValueResult<TyParam> {
|
||||
let arr = args
|
||||
.remove_left_or_key("Self")
|
||||
.ok_or_else(|| not_passed("Self"))?;
|
||||
let res = _array_prod(arr, ctx).unwrap();
|
||||
let arr = TyParam::Value(res);
|
||||
Ok(arr)
|
||||
}
|
||||
|
||||
pub(crate) fn __range_getitem__(mut args: ValueArgs, _ctx: &Context) -> EvalValueResult<TyParam> {
|
||||
let slf = args
|
||||
.remove_left_or_key("Self")
|
||||
|
|
|
@ -419,6 +419,7 @@ const FUNC_DELATTR: &str = "delattr";
|
|||
const FUNC_NEARLY_EQ: &str = "nearly_eq";
|
||||
const FUNC_RESOLVE_PATH: &str = "ResolvePath";
|
||||
const FUNC_RESOLVE_DECL_PATH: &str = "ResolveDeclPath";
|
||||
const FUNC_PROD: &str = "prod";
|
||||
|
||||
const OP_EQ: &str = "__eq__";
|
||||
const OP_HASH: &str = "__hash__";
|
||||
|
|
|
@ -137,10 +137,6 @@ impl TyVarCache {
|
|||
}
|
||||
}
|
||||
|
||||
fn _instantiate_pred(&self, _pred: Predicate) -> Predicate {
|
||||
todo!()
|
||||
}
|
||||
|
||||
/// Some of the quantified types are circulating.
|
||||
/// e.g.
|
||||
/// ```erg
|
||||
|
@ -576,6 +572,23 @@ impl Context {
|
|||
let value = self.instantiate_value(value, tmp_tv_cache, loc)?;
|
||||
Ok(Predicate::Value(value))
|
||||
}
|
||||
Predicate::Call {
|
||||
receiver,
|
||||
name,
|
||||
args,
|
||||
} => {
|
||||
let receiver = self.instantiate_tp(receiver, tmp_tv_cache, loc)?;
|
||||
let mut new_args = Vec::with_capacity(args.len());
|
||||
for arg in args {
|
||||
new_args.push(self.instantiate_tp(arg, tmp_tv_cache, loc)?);
|
||||
}
|
||||
Ok(Predicate::call(receiver, name, new_args))
|
||||
}
|
||||
Predicate::GeneralEqual { lhs, rhs } => {
|
||||
let lhs = self.instantiate_pred(*lhs, tmp_tv_cache, loc)?;
|
||||
let rhs = self.instantiate_pred(*rhs, tmp_tv_cache, loc)?;
|
||||
Ok(Predicate::general_eq(lhs, rhs))
|
||||
}
|
||||
_ => Ok(pred),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1462,6 +1462,20 @@ impl Context {
|
|||
self.inc_ref_local(local, self, tmp_tv_cache);
|
||||
Ok(Predicate::Const(local.inspect().clone()))
|
||||
}
|
||||
ast::ConstExpr::App(app) => {
|
||||
let receiver = self.instantiate_const_expr(&app.obj, None, tmp_tv_cache, false)?;
|
||||
let name = app.attr_name.as_ref().map(|n| n.inspect().to_owned());
|
||||
let mut args = vec![];
|
||||
for arg in app.args.pos_args() {
|
||||
let arg = self.instantiate_const_expr(&arg.expr, None, tmp_tv_cache, false)?;
|
||||
args.push(arg);
|
||||
}
|
||||
Ok(Predicate::Call {
|
||||
receiver,
|
||||
name,
|
||||
args,
|
||||
})
|
||||
}
|
||||
ast::ConstExpr::BinOp(bin) => {
|
||||
let lhs = self.instantiate_pred_from_expr(&bin.lhs, tmp_tv_cache)?;
|
||||
let rhs = self.instantiate_pred_from_expr(&bin.rhs, tmp_tv_cache)?;
|
||||
|
@ -1472,16 +1486,25 @@ impl Context {
|
|||
| TokenKind::LessEq
|
||||
| TokenKind::Gre
|
||||
| TokenKind::GreEq => {
|
||||
let Predicate::Const(var) = lhs else {
|
||||
return type_feature_error!(
|
||||
self,
|
||||
bin.loc(),
|
||||
&format!("instantiating predicate `{expr}`")
|
||||
);
|
||||
let var = match lhs {
|
||||
Predicate::Const(var) => var,
|
||||
other if bin.op.kind == TokenKind::DblEq => {
|
||||
return Ok(Predicate::general_eq(other, rhs));
|
||||
}
|
||||
_ => {
|
||||
return type_feature_error!(
|
||||
self,
|
||||
bin.loc(),
|
||||
&format!("instantiating predicate `{expr}`")
|
||||
);
|
||||
}
|
||||
};
|
||||
let rhs = match rhs {
|
||||
Predicate::Value(value) => TyParam::Value(value),
|
||||
Predicate::Const(var) => TyParam::Mono(var),
|
||||
other if bin.op.kind == TokenKind::DblEq => {
|
||||
return Ok(Predicate::general_eq(Predicate::Const(var), other));
|
||||
}
|
||||
_ => {
|
||||
return type_feature_error!(
|
||||
self,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue