fix: type generalization bug

This commit is contained in:
Shunsuke Shibayama 2024-05-15 22:01:31 +09:00
parent dc1e32f5f4
commit dc7565cb26
6 changed files with 70 additions and 34 deletions

View file

@ -8,6 +8,7 @@ use erg_common::{dict, fn_name, get_hash, set};
#[allow(unused_imports)]
use erg_common::{fmt_vec, log};
use crate::hir::GuardClause;
use crate::module::GeneralizationResult;
use crate::ty::constructors::*;
use crate::ty::free::{CanbeFree, Constraint, Free, HasLevel};
@ -1425,6 +1426,23 @@ impl Context {
param.sig.vi.t = dereferencer.deref_tyvar(t)?;
self.resolve_expr_t(&mut param.default_val, qnames)?;
}
if let Some(kw_var) = &mut params.kw_var_params {
kw_var.vi.t.generalize();
let t = mem::take(&mut kw_var.vi.t);
let mut dereferencer =
Dereferencer::new(self, Contravariant, false, qnames, kw_var.as_ref());
kw_var.vi.t = dereferencer.deref_tyvar(t)?;
}
for guard in params.guards.iter_mut() {
match guard {
GuardClause::Bind(def) => {
self.resolve_def_t(def, qnames)?;
}
GuardClause::Condition(cond) => {
self.resolve_expr_t(cond, qnames)?;
}
}
}
Ok(())
}
@ -1435,17 +1453,9 @@ impl Context {
match expr {
hir::Expr::Literal(_) => Ok(()),
hir::Expr::Accessor(acc) => {
if acc
.ref_t()
.unbound_name()
.map_or(false, |name| !qnames.contains(&name))
{
let t = mem::take(acc.ref_mut_t().unwrap());
let mut dereferencer = Dereferencer::simple(self, qnames, acc);
*acc.ref_mut_t().unwrap() = dereferencer.deref_tyvar(t)?;
} else {
acc.ref_mut_t().unwrap().dereference();
}
let t = mem::take(acc.ref_mut_t().unwrap());
let mut dereferencer = Dereferencer::simple(self, qnames, acc);
*acc.ref_mut_t().unwrap() = dereferencer.deref_tyvar(t)?;
if let hir::Accessor::Attr(attr) = acc {
self.resolve_expr_t(&mut attr.obj, qnames)?;
}
@ -1559,7 +1569,6 @@ impl Context {
Ok(())
}
hir::Expr::Call(call) => {
self.resolve_expr_t(&mut call.obj, qnames)?;
for arg in call.args.pos_args.iter_mut() {
self.resolve_expr_t(&mut arg.expr, qnames)?;
}
@ -1569,6 +1578,10 @@ impl Context {
for arg in call.args.kw_args.iter_mut() {
self.resolve_expr_t(&mut arg.expr, qnames)?;
}
if let Some(kw_var) = &mut call.args.kw_var {
self.resolve_expr_t(&mut kw_var.expr, qnames)?;
}
self.resolve_expr_t(&mut call.obj, qnames)?;
if let Some(t) = call.signature_mut_t() {
let t = mem::take(t);
let mut dereferencer = Dereferencer::simple(self, qnames, call);
@ -1576,27 +1589,7 @@ impl Context {
}
Ok(())
}
hir::Expr::Def(def) => {
let qnames = if let Type::Quantified(quant) = def.sig.ref_t() {
// double quantification is not allowed
let Ok(subr) = <&SubrType>::try_from(quant.as_ref()) else {
unreachable!()
};
subr.essential_qnames()
} else {
qnames.clone()
};
let t = mem::take(def.sig.ref_mut_t().unwrap());
let mut dereferencer = Dereferencer::simple(self, &qnames, &def.sig);
*def.sig.ref_mut_t().unwrap() = dereferencer.deref_tyvar(t)?;
if let Some(params) = def.sig.params_mut() {
self.resolve_params_t(params, &qnames)?;
}
for chunk in def.body.block.iter_mut() {
self.resolve_expr_t(chunk, &qnames)?;
}
Ok(())
}
hir::Expr::Def(def) => self.resolve_def_t(def, qnames),
hir::Expr::Lambda(lambda) => {
let qnames = if let Type::Quantified(quant) = lambda.ref_t() {
let Ok(subr) = <&SubrType>::try_from(quant.as_ref()) else {
@ -1663,6 +1656,28 @@ impl Context {
}
}
fn resolve_def_t(&self, def: &mut hir::Def, qnames: &Set<Str>) -> TyCheckResult<()> {
let qnames = if let Type::Quantified(quant) = def.sig.ref_t() {
// double quantification is not allowed
let Ok(subr) = <&SubrType>::try_from(quant.as_ref()) else {
unreachable!()
};
subr.essential_qnames()
} else {
qnames.clone()
};
let t = mem::take(def.sig.ref_mut_t().unwrap());
let mut dereferencer = Dereferencer::simple(self, &qnames, &def.sig);
*def.sig.ref_mut_t().unwrap() = dereferencer.deref_tyvar(t)?;
if let Some(params) = def.sig.params_mut() {
self.resolve_params_t(params, &qnames)?;
}
for chunk in def.body.block.iter_mut() {
self.resolve_expr_t(chunk, &qnames)?;
}
Ok(())
}
/// ```erg
/// squash_tyvar(?1 or ?2) == ?1(== ?2)
/// squash_tyvar(?T or ?U) == ?T or ?U

View file

@ -1526,6 +1526,8 @@ pub(crate) fn zip_func(mut args: ValueArgs, _ctx: &Context) -> EvalValueResult<T
/// ```erg
/// derefine({X: T | ...}) == T
/// derefine({1}) == Nat
/// derefine(List!({1, 2}, 2)) == List!(Nat, 2)
/// ```
pub(crate) fn derefine_func(mut args: ValueArgs, ctx: &Context) -> EvalValueResult<TyParam> {
let val = args