fix: type generalization bug

This commit is contained in:
Shunsuke Shibayama 2024-05-15 22:01:31 +09:00
parent dc1e32f5f4
commit dc7565cb26
6 changed files with 70 additions and 34 deletions

View file

@ -8,6 +8,7 @@ use erg_common::{dict, fn_name, get_hash, set};
#[allow(unused_imports)]
use erg_common::{fmt_vec, log};
use crate::hir::GuardClause;
use crate::module::GeneralizationResult;
use crate::ty::constructors::*;
use crate::ty::free::{CanbeFree, Constraint, Free, HasLevel};
@ -1425,6 +1426,23 @@ impl Context {
param.sig.vi.t = dereferencer.deref_tyvar(t)?;
self.resolve_expr_t(&mut param.default_val, qnames)?;
}
if let Some(kw_var) = &mut params.kw_var_params {
kw_var.vi.t.generalize();
let t = mem::take(&mut kw_var.vi.t);
let mut dereferencer =
Dereferencer::new(self, Contravariant, false, qnames, kw_var.as_ref());
kw_var.vi.t = dereferencer.deref_tyvar(t)?;
}
for guard in params.guards.iter_mut() {
match guard {
GuardClause::Bind(def) => {
self.resolve_def_t(def, qnames)?;
}
GuardClause::Condition(cond) => {
self.resolve_expr_t(cond, qnames)?;
}
}
}
Ok(())
}
@ -1435,17 +1453,9 @@ impl Context {
match expr {
hir::Expr::Literal(_) => Ok(()),
hir::Expr::Accessor(acc) => {
if acc
.ref_t()
.unbound_name()
.map_or(false, |name| !qnames.contains(&name))
{
let t = mem::take(acc.ref_mut_t().unwrap());
let mut dereferencer = Dereferencer::simple(self, qnames, acc);
*acc.ref_mut_t().unwrap() = dereferencer.deref_tyvar(t)?;
} else {
acc.ref_mut_t().unwrap().dereference();
}
let t = mem::take(acc.ref_mut_t().unwrap());
let mut dereferencer = Dereferencer::simple(self, qnames, acc);
*acc.ref_mut_t().unwrap() = dereferencer.deref_tyvar(t)?;
if let hir::Accessor::Attr(attr) = acc {
self.resolve_expr_t(&mut attr.obj, qnames)?;
}
@ -1559,7 +1569,6 @@ impl Context {
Ok(())
}
hir::Expr::Call(call) => {
self.resolve_expr_t(&mut call.obj, qnames)?;
for arg in call.args.pos_args.iter_mut() {
self.resolve_expr_t(&mut arg.expr, qnames)?;
}
@ -1569,6 +1578,10 @@ impl Context {
for arg in call.args.kw_args.iter_mut() {
self.resolve_expr_t(&mut arg.expr, qnames)?;
}
if let Some(kw_var) = &mut call.args.kw_var {
self.resolve_expr_t(&mut kw_var.expr, qnames)?;
}
self.resolve_expr_t(&mut call.obj, qnames)?;
if let Some(t) = call.signature_mut_t() {
let t = mem::take(t);
let mut dereferencer = Dereferencer::simple(self, qnames, call);
@ -1576,27 +1589,7 @@ impl Context {
}
Ok(())
}
hir::Expr::Def(def) => {
let qnames = if let Type::Quantified(quant) = def.sig.ref_t() {
// double quantification is not allowed
let Ok(subr) = <&SubrType>::try_from(quant.as_ref()) else {
unreachable!()
};
subr.essential_qnames()
} else {
qnames.clone()
};
let t = mem::take(def.sig.ref_mut_t().unwrap());
let mut dereferencer = Dereferencer::simple(self, &qnames, &def.sig);
*def.sig.ref_mut_t().unwrap() = dereferencer.deref_tyvar(t)?;
if let Some(params) = def.sig.params_mut() {
self.resolve_params_t(params, &qnames)?;
}
for chunk in def.body.block.iter_mut() {
self.resolve_expr_t(chunk, &qnames)?;
}
Ok(())
}
hir::Expr::Def(def) => self.resolve_def_t(def, qnames),
hir::Expr::Lambda(lambda) => {
let qnames = if let Type::Quantified(quant) = lambda.ref_t() {
let Ok(subr) = <&SubrType>::try_from(quant.as_ref()) else {
@ -1663,6 +1656,28 @@ impl Context {
}
}
fn resolve_def_t(&self, def: &mut hir::Def, qnames: &Set<Str>) -> TyCheckResult<()> {
let qnames = if let Type::Quantified(quant) = def.sig.ref_t() {
// double quantification is not allowed
let Ok(subr) = <&SubrType>::try_from(quant.as_ref()) else {
unreachable!()
};
subr.essential_qnames()
} else {
qnames.clone()
};
let t = mem::take(def.sig.ref_mut_t().unwrap());
let mut dereferencer = Dereferencer::simple(self, &qnames, &def.sig);
*def.sig.ref_mut_t().unwrap() = dereferencer.deref_tyvar(t)?;
if let Some(params) = def.sig.params_mut() {
self.resolve_params_t(params, &qnames)?;
}
for chunk in def.body.block.iter_mut() {
self.resolve_expr_t(chunk, &qnames)?;
}
Ok(())
}
/// ```erg
/// squash_tyvar(?1 or ?2) == ?1(== ?2)
/// squash_tyvar(?T or ?U) == ?T or ?U

View file

@ -1526,6 +1526,8 @@ pub(crate) fn zip_func(mut args: ValueArgs, _ctx: &Context) -> EvalValueResult<T
/// ```erg
/// derefine({X: T | ...}) == T
/// derefine({1}) == Nat
/// derefine(List!({1, 2}, 2)) == List!(Nat, 2)
/// ```
pub(crate) fn derefine_func(mut args: ValueArgs, ctx: &Context) -> EvalValueResult<TyParam> {
let val = args

View file

@ -1,6 +1,8 @@
.abc = pyimport "abc"
.NamedTuple = 'namedtuple': ClassType
.NamedTuple.
__call__: (typename: Str, field_names: Sequence(Str), rename := Bool) -> (*Obj, **Obj) -> NamedTuple
.Deque = 'deque': ClassType
.ChainMap: ClassType
.Counter: ClassType

View file

@ -1663,6 +1663,7 @@ impl Not for Type {
fn get_t_from_tp(tp: &TyParam) -> Option<Type> {
match tp {
TyParam::FreeVar(fv) if fv.is_linked() => get_t_from_tp(&fv.crack()),
TyParam::Value(ValueObj::Type(t)) => Some(t.typ().clone()),
TyParam::Type(t) => Some(*t.clone()),
_ => None,
}
@ -2229,7 +2230,7 @@ impl Type {
Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_nonelike(),
Self::NoneType => true,
Self::Poly { name, params, .. } if &name[..] == "Option" || &name[..] == "Option!" => {
let Some(TyParam::Type(inner_t)) = params.first() else {
let Some(inner_t) = params.first().and_then(|tp| <&Type>::try_from(tp).ok()) else {
return false;
};
inner_t.is_nonelike()
@ -3647,6 +3648,9 @@ impl Type {
let params = params
.iter()
.map(|tp| match tp {
TyParam::Value(ValueObj::Type(t)) => {
TyParam::Value(ValueObj::Type(t.clone().mapped_t(|t| t.derefine())))
}
TyParam::Type(t) => TyParam::t(t.derefine()),
other => other.clone(),
})
@ -3679,12 +3683,18 @@ impl Type {
args,
} => {
let lhs = match lhs.as_ref() {
TyParam::Value(ValueObj::Type(t)) => {
TyParam::Value(ValueObj::Type(t.clone().mapped_t(|t| t.derefine())))
}
TyParam::Type(t) => TyParam::t(t.derefine()),
other => other.clone(),
};
let args = args
.iter()
.map(|arg| match arg {
TyParam::Value(ValueObj::Type(t)) => {
TyParam::Value(ValueObj::Type(t.clone().mapped_t(|t| t.derefine())))
}
TyParam::Type(t) => TyParam::t(t.derefine()),
other => other.clone(),
})

View file

@ -676,6 +676,7 @@ impl<'t> TryFrom<&'t TyParam> for &'t FreeTyVar {
fn try_from(t: &'t TyParam) -> Result<&'t FreeTyVar, ()> {
match t {
TyParam::Type(ty) => <&FreeTyVar>::try_from(ty.as_ref()),
TyParam::Value(ValueObj::Type(ty)) => <&FreeTyVar>::try_from(ty.typ()),
_ => Err(()),
}
}
@ -1167,6 +1168,7 @@ impl TyParam {
typ.undoable_coerce(list);
}
TyParam::Type(t) => t.undoable_coerce(list),
TyParam::Value(ValueObj::Type(t)) => t.typ().undoable_coerce(list),
// TODO:
_ => {}
}

View file

@ -488,6 +488,11 @@ impl TypeObj {
}
}
pub fn mapped_t(mut self, f: impl FnOnce(Type) -> Type) -> Self {
self.map_t(f);
self
}
pub fn try_map_t<E>(&mut self, f: impl FnOnce(Type) -> Result<Type, E>) -> Result<(), E> {
match self {
TypeObj::Builtin { t, .. } => {