fix: forward-referenced method inference bug

This commit is contained in:
Shunsuke Shibayama 2023-10-29 21:24:40 +09:00
parent 34a20e7005
commit 6713ffeaae
8 changed files with 194 additions and 69 deletions

View file

@ -199,6 +199,21 @@ impl Context {
.map(|(opt_name, vi)| (opt_name.as_ref().unwrap(), vi))
}
pub fn get_method_kv(&self, name: &str) -> Option<(&VarName, &VarInfo)> {
#[cfg(feature = "py_compat")]
let name = self.erg_to_py_names.get(name).map_or(name, |s| &s[..]);
self.get_var_kv(name)
.or_else(|| {
for methods in self.methods_list.iter() {
if let Some(vi) = methods.get_method_kv(name) {
return Some(vi);
}
}
None
})
.or_else(|| self.get_outer().and_then(|ctx| ctx.get_method_kv(name)))
}
pub fn get_singular_ctxs_by_hir_expr(
&self,
obj: &hir::Expr,

View file

@ -741,6 +741,91 @@ impl Context {
}
}
fn unify_params_t(
&self,
sig: &ast::SubrSignature,
registered_t: &SubrType,
params: &hir::Params,
body_t: &Type,
body_loc: &impl Locational,
) -> TyCheckResult<()> {
let name = &sig.ident.name;
let mut errs = TyCheckErrors::empty();
for (param, pt) in params
.non_defaults
.iter()
.zip(registered_t.non_default_params.iter())
{
pt.typ().lower();
if let Err(es) = self.force_sub_unify(&param.vi.t, pt.typ(), param, None) {
errs.extend(es);
}
pt.typ().lift();
}
// TODO: var_params: [Int; _], pt: Int
/*if let Some((var_params, pt)) = params.var_params.as_deref().zip(registered_t.var_params.as_ref()) {
pt.typ().lower();
if let Err(es) = self.force_sub_unify(&var_params.vi.t, pt.typ(), var_params, None) {
errs.extend(es);
}
pt.typ().lift();
}*/
for (param, pt) in params
.defaults
.iter()
.zip(registered_t.default_params.iter())
{
pt.typ().lower();
if let Err(es) = self.force_sub_unify(&param.sig.vi.t, pt.typ(), param, None) {
errs.extend(es);
}
pt.typ().lift();
}
let spec_ret_t = registered_t.return_t.as_ref();
// spec_ret_t.lower();
let unify_return_result = if let Some(t_spec) = sig.return_t_spec.as_ref() {
self.force_sub_unify(body_t, spec_ret_t, t_spec, None)
} else {
self.force_sub_unify(body_t, spec_ret_t, body_loc, None)
};
// spec_ret_t.lift();
if let Err(unify_errs) = unify_return_result {
let es = TyCheckErrors::new(
unify_errs
.into_iter()
.map(|e| {
let expect = if cfg!(feature = "debug") {
spec_ret_t.clone()
} else {
self.readable_type(spec_ret_t.clone())
};
let found = if cfg!(feature = "debug") {
body_t.clone()
} else {
self.readable_type(body_t.clone())
};
TyCheckError::return_type_error(
self.cfg.input.clone(),
line!() as usize,
e.core.get_loc_with_fallback(),
e.caused_by,
readable_name(name.inspect()),
&expect,
&found,
// e.core.get_hint().map(|s| s.to_string()),
)
})
.collect(),
);
errs.extend(es);
}
if errs.is_empty() {
Ok(())
} else {
Err(errs)
}
}
/// ## Errors
/// * TypeError: if `return_t` != typeof `body`
/// * AssignError: if `name` has already been registered
@ -748,6 +833,7 @@ impl Context {
&mut self,
sig: &ast::SubrSignature,
id: DefId,
params: &hir::Params,
body_t: &Type,
body_loc: &impl Locational,
) -> Result<VarInfo, (TyCheckErrors, VarInfo)> {
@ -772,63 +858,27 @@ impl Context {
};
let name = &sig.ident.name;
// FIXME: constでない関数
let t = self.get_current_scope_var(name).map(|vi| &vi.t).unwrap();
debug_assert!(t.is_subr(), "{t} is not subr");
let empty = vec![];
let non_default_params = t.non_default_params().unwrap_or(&empty);
let var_args = t.var_params();
let default_params = t.default_params().unwrap_or(&empty);
if let Some(spec_ret_t) = t.return_t() {
let unify_result = if let Some(t_spec) = sig.return_t_spec.as_ref() {
self.sub_unify(body_t, spec_ret_t, t_spec, None)
} else {
self.sub_unify(body_t, spec_ret_t, body_loc, None)
};
if let Err(unify_errs) = unify_result {
let es = TyCheckErrors::new(
unify_errs
.into_iter()
.map(|e| {
let expect = if cfg!(feature = "debug") {
spec_ret_t.clone()
} else {
self.readable_type(spec_ret_t.clone())
};
let found = if cfg!(feature = "debug") {
body_t.clone()
} else {
self.readable_type(body_t.clone())
};
TyCheckError::return_type_error(
self.cfg.input.clone(),
line!() as usize,
e.core.get_loc_with_fallback(),
e.caused_by,
readable_name(name.inspect()),
&expect,
&found,
// e.core.get_hint().map(|s| s.to_string()),
)
})
.collect(),
);
errs.extend(es);
}
let subr_t = self.get_current_scope_var(name).map(|vi| &vi.t).unwrap();
let Ok(subr_t) = <&SubrType>::try_from(subr_t) else {
panic!("{subr_t} is not subr");
};
if let Err(es) = self.unify_params_t(sig, subr_t, params, body_t, body_loc) {
errs.extend(es);
}
// NOTE: not `body_t.clone()` because the body may contain `return`
let return_t = t.return_t().unwrap().clone();
let return_t = subr_t.return_t.as_ref().clone();
let sub_t = if sig.ident.is_procedural() {
proc(
non_default_params.clone(),
var_args.cloned(),
default_params.clone(),
subr_t.non_default_params.clone(),
subr_t.var_params.as_deref().cloned(),
subr_t.default_params.clone(),
return_t,
)
} else {
func(
non_default_params.clone(),
var_args.cloned(),
default_params.clone(),
subr_t.non_default_params.clone(),
subr_t.var_params.as_deref().cloned(),
subr_t.default_params.clone(),
return_t,
)
};

View file

@ -24,6 +24,22 @@ impl Context {
}
}
pub fn assert_attr_type(&self, receiver_t: &Type, attr: &str, ty: &Type) -> Result<(), ()> {
let Some(ctx) = self.get_nominal_type_ctx(receiver_t) else {
panic!("type not found: {receiver_t}");
};
let Some((_, vi)) = ctx.get_method_kv(attr) else {
panic!("attribute not found: {attr}");
};
println!("{attr}: {}", vi.t);
if vi.t.structural_eq(ty) {
Ok(())
} else {
println!("{attr} is not the type of {ty}");
Err(())
}
}
pub fn test_refinement_subtyping(&self) -> Result<(), ()> {
// Nat :> {I: Int | I >= 1} ?
let lhs = Nat;

View file

@ -31,6 +31,7 @@ pub struct Unifier<'c, 'l, 'u, L: Locational> {
ctx: &'c Context,
loc: &'l L,
undoable: Option<&'u UndoableLinkedList>,
change_generalized: bool,
param_name: Option<Str>,
}
@ -39,12 +40,14 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
ctx: &'c Context,
loc: &'l L,
undoable: Option<&'u UndoableLinkedList>,
change_generalized: bool,
param_name: Option<Str>,
) -> Self {
Self {
ctx,
loc,
undoable,
change_generalized,
param_name,
}
}
@ -326,7 +329,11 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
}
Ok(())
}
(TyParam::FreeVar(sub_fv), _) if sub_fv.is_generalized() => Ok(()),
(TyParam::FreeVar(sub_fv), _)
if !self.change_generalized && sub_fv.is_generalized() =>
{
Ok(())
}
(TyParam::FreeVar(sub_fv), sup_tp) => {
match &*sub_fv.borrow() {
FreeKind::Linked(l) | FreeKind::UndoableLinked { t: l, .. } => {
@ -366,7 +373,11 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
)))
}
}
(_, TyParam::FreeVar(sup_fv)) if sup_fv.is_generalized() => Ok(()),
(_, TyParam::FreeVar(sup_fv))
if !self.change_generalized && sup_fv.is_generalized() =>
{
Ok(())
}
(sub_tp, TyParam::FreeVar(sup_fv)) => {
match &*sup_fv.borrow() {
FreeKind::Linked(l) | FreeKind::UndoableLinked { t: l, .. } => {
@ -760,7 +771,8 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
(FreeVar(sub_fv), FreeVar(sup_fv))
if sub_fv.constraint_is_sandwiched() && sup_fv.constraint_is_sandwiched() =>
{
if sub_fv.is_generalized() || sup_fv.is_generalized() {
if !self.change_generalized && (sub_fv.is_generalized() || sup_fv.is_generalized())
{
log!(info "generalized:\nmaybe_sub: {maybe_sub}\nmaybe_sup: {maybe_sup}");
return Ok(());
}
@ -860,7 +872,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
},
FreeVar(sup_fv),
) if sup_fv.constraint_is_sandwiched() => {
if sup_fv.is_generalized() {
if !self.change_generalized && sup_fv.is_generalized() {
log!(info "generalized:\nmaybe_sub: {maybe_sub}\nmaybe_sup: {maybe_sup}");
return Ok(());
}
@ -958,7 +970,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
// e.g. Structural({ .method = (self: T) -> Int })/T
(Structural(sub), FreeVar(sup_fv))
if sup_fv.is_unbound() && sub.contains_tvar(sup_fv) => {}
(_, FreeVar(sup_fv)) if sup_fv.is_generalized() => {}
(_, FreeVar(sup_fv)) if !self.change_generalized && sup_fv.is_generalized() => {}
(_, FreeVar(sup_fv)) if sup_fv.is_unbound() => {
// * sub_unify(Nat, ?E(<: Eq(?E)))
// sub !<: l => OK (sub will widen)
@ -1037,7 +1049,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
(FreeVar(sub_fv), Ref(sup)) if sub_fv.is_unbound() => {
self.sub_unify(maybe_sub, sup)?;
}
(FreeVar(sub_fv), _) if sub_fv.is_generalized() => {}
(FreeVar(sub_fv), _) if !self.change_generalized && sub_fv.is_generalized() => {}
(FreeVar(sub_fv), _) if sub_fv.is_unbound() => {
// sub !<: r => Error
// * sub_unify(?T(:> Int, <: _), Nat): (/* Error */)
@ -1165,7 +1177,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
.iter()
.zip(sup_subr.non_default_params.iter())
.try_for_each(|(sub, sup)| {
if sub.typ().is_generalized() {
if !self.change_generalized && sub.typ().is_generalized() {
Ok(())
}
// contravariant
@ -1179,7 +1191,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
.iter()
.find(|sub_pt| sub_pt.name() == sup_pt.name())
{
if sup_pt.typ().is_generalized() {
if !self.change_generalized && sup_pt.typ().is_generalized() {
continue;
}
// contravariant
@ -1203,7 +1215,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
.zip(sup_subr.non_default_params.iter())
.try_for_each(|(sub, sup)| {
// contravariant
if sup.typ().is_generalized() {
if !self.change_generalized && sup.typ().is_generalized() {
Ok(())
} else {
self.sub_unify(sup.typ(), sub.typ())
@ -1216,7 +1228,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
.find(|sub_pt| sub_pt.name() == sup_pt.name())
{
// contravariant
if sup_pt.typ().is_generalized() {
if !self.change_generalized && sup_pt.typ().is_generalized() {
continue;
}
self.sub_unify(sup_pt.typ(), sub_pt.typ())?;
@ -1555,7 +1567,7 @@ impl Context {
maybe_sup: &Type,
loc: &impl Locational,
) -> TyCheckResult<()> {
let unifier = Unifier::new(self, loc, None, None);
let unifier = Unifier::new(self, loc, None, false, None);
unifier.occur(maybe_sub, maybe_sup)
}
@ -1567,7 +1579,7 @@ impl Context {
loc: &impl Locational,
is_structural: bool,
) -> TyCheckResult<()> {
let unifier = Unifier::new(self, loc, None, None);
let unifier = Unifier::new(self, loc, None, false, None);
unifier.sub_unify_tp(maybe_sub, maybe_sup, variance, is_structural)
}
@ -1579,7 +1591,19 @@ impl Context {
loc: &impl Locational,
param_name: Option<&Str>,
) -> TyCheckResult<()> {
let unifier = Unifier::new(self, loc, None, param_name.cloned());
let unifier = Unifier::new(self, loc, None, false, param_name.cloned());
unifier.sub_unify(maybe_sub, maybe_sup)
}
/// This will rewrite generalized type variables.
pub(crate) fn force_sub_unify(
&self,
maybe_sub: &Type,
maybe_sup: &Type,
loc: &impl Locational,
param_name: Option<&Str>,
) -> TyCheckResult<()> {
let unifier = Unifier::new(self, loc, None, true, param_name.cloned());
unifier.sub_unify(maybe_sub, maybe_sup)
}
@ -1591,12 +1615,12 @@ impl Context {
list: &UndoableLinkedList,
param_name: Option<&Str>,
) -> TyCheckResult<()> {
let unifier = Unifier::new(self, loc, Some(list), param_name.cloned());
let unifier = Unifier::new(self, loc, Some(list), false, param_name.cloned());
unifier.sub_unify(maybe_sub, maybe_sup)
}
pub(crate) fn unify(&self, lhs: &Type, rhs: &Type) -> Option<Type> {
let unifier = Unifier::new(self, &(), None, None);
let unifier = Unifier::new(self, &(), None, false, None);
unifier.unify(lhs, rhs)
}
}

View file

@ -1956,25 +1956,26 @@ impl ASTLowerer {
fn lower_subr_block(
&mut self,
subr_t: SubrType,
registered_subr_t: SubrType,
sig: ast::SubrSignature,
decorators: Set<hir::Expr>,
body: ast::DefBody,
) -> LowerResult<hir::Def> {
let params = self.lower_params(sig.params.clone(), Some(&subr_t))?;
let params = self.lower_params(sig.params.clone(), Some(&registered_subr_t))?;
if let Err(errs) = self.module.context.register_const(&body.block) {
self.errs.extend(errs);
}
let return_t = subr_t
let return_t = registered_subr_t
.return_t
.has_no_unbound_var()
.then_some(subr_t.return_t.as_ref());
.then_some(registered_subr_t.return_t.as_ref());
match self.lower_block(body.block, return_t) {
Ok(block) => {
let found_body_t = self.module.context.squash_tyvar(block.t());
let vi = match self.module.context.outer.as_mut().unwrap().assign_subr(
&sig,
body.id,
&params,
&found_body_t,
block.last().unwrap(),
) {
@ -2009,6 +2010,7 @@ impl ASTLowerer {
let vi = match self.module.context.outer.as_mut().unwrap().assign_subr(
&sig,
ast::DefId(0),
&params,
&Type::Failure,
&sig,
) {

View file

@ -25,3 +25,8 @@ f! t =
for! arr, t =>
result.extend! f! t
result
c_new x, y = C.new x, y
C = Class Int
C.
new x, y = Self::__new__ x + y

View file

@ -1,3 +1,5 @@
use std::vec;
use erg_common::config::ErgConfig;
use erg_common::error::MultiErrorDisplay;
use erg_common::io::Output;
@ -79,6 +81,12 @@ fn _test_infer_types() -> Result<(), ()> {
let t = type_q("T");
let f_t = proc1(t.clone(), unknown_len_array_mut(t)).quantify();
module.context.assert_var_type("f!", &f_t)?;
let r = type_q("R");
let add_r = poly("Add", vec![ty_tp(r.clone())]);
let c = mono("<module>::C");
let c_new_t = func2(add_r, r, c.clone()).quantify();
module.context.assert_var_type("c_new", &c_new_t)?;
module.context.assert_attr_type(&c, "new", &c_new_t)?;
Ok(())
}

View file

@ -23,7 +23,7 @@ static UNBOUND_ID: AtomicUsize = AtomicUsize::new(0);
pub trait HasLevel {
fn level(&self) -> Option<Level>;
fn set_level(&self, lev: Level);
fn lower(&self, level: Level) {
fn set_lower(&self, level: Level) {
if self.level() < Some(level) {
self.set_level(level);
}
@ -33,6 +33,11 @@ pub trait HasLevel {
self.set_level(lev.saturating_add(1));
}
}
fn lower(&self) {
if let Some(lev) = self.level() {
self.set_level(lev.saturating_sub(1));
}
}
fn generalize(&self) {
self.set_level(GENERIC_LEVEL);
}