fix: array type inffering

This commit is contained in:
Shunsuke Shibayama 2023-10-16 23:16:39 +09:00
parent 7cd3bce9f1
commit 5096843bc5
7 changed files with 142 additions and 51 deletions

View file

@ -503,6 +503,14 @@ macro_rules! log {
}
}};
(caller) => {{
if cfg!(feature = "debug") {
use $crate::style::*;
$crate::debug_info!();
println!("\n{}", std::panic::Location::caller());
}
}};
($($arg: tt)*) => {{
if cfg!(feature = "debug") {
use $crate::style::*;

View file

@ -23,7 +23,7 @@ use crate::ty::constructors::{
array_t, bounded, dict_t, mono, mono_q, named_free_var, poly, proj, proj_call, ref_, ref_mut,
refinement, set_t, subr_t, subtypeof, tp_enum, tuple_t, unknown_len_array_t, v_enum,
};
use crate::ty::free::{Constraint, HasLevel};
use crate::ty::free::HasLevel;
use crate::ty::typaram::{OpKind, TyParam};
use crate::ty::value::{GenTypeObj, TypeObj, ValueObj};
use crate::ty::{
@ -1561,9 +1561,9 @@ impl Context {
} else {
self.eval_t_params(sup, level, t_loc)?
};
let new_constraint = Constraint::new_sandwiched(sub, sup);
fv.update_constraint(new_constraint, false);
Ok(Type::FreeVar(fv))
let fv = Type::FreeVar(fv);
fv.update_tyvar(sub, sup, None, false);
Ok(fv)
}
Type::Subr(mut subr) => {
for pt in subr.non_default_params.iter_mut() {

View file

@ -353,30 +353,7 @@ impl ASTLowerer {
for elem in elems.into_iter() {
let elem = self.lower_expr(elem.expr, expect_elem.as_ref())?;
let union_ = self.module.context.union(&union, elem.ref_t());
if let Some((l, r)) = union_.union_pair() {
match (l.is_unbound_var(), r.is_unbound_var()) {
// e.g. [1, "a"]
(false, false) => {
if let hir::Expr::TypeAsc(type_asc) = &elem {
// e.g. [1, "a": Str or NoneType]
if ERG_MODE
&& !self
.module
.context
.supertype_of(&type_asc.spec.spec_t, &union)
{
return Err(self.elem_err(&l, &r, &elem));
} // else(OK): e.g. [1, "a": Str or Int]
} else if ERG_MODE {
return Err(self.elem_err(&l, &r, &elem));
}
}
// TODO: check if the type is compatible with the other type
(true, false) => {}
(false, true) => {}
(true, true) => {}
}
}
self.homogeneity_check(expect_elem.as_ref(), &union_, &union, &elem)?;
union = union_;
new_array.push(elem);
}
@ -398,6 +375,41 @@ impl ASTLowerer {
Ok(hir::NormalArray::new(array.l_sqbr, array.r_sqbr, t, elems))
}
fn homogeneity_check(
&self,
expect_elem: Option<&Type>,
union_: &Type,
union: &Type,
elem: &hir::Expr,
) -> LowerResult<()> {
if ERG_MODE && expect_elem.is_none() {
if let Some((l, r)) = union_.union_pair() {
match (l.is_unbound_var(), r.is_unbound_var()) {
// e.g. [1, "a"]
(false, false) => {
if let hir::Expr::TypeAsc(type_asc) = elem {
// e.g. [1, "a": Str or NoneType]
if !self
.module
.context
.supertype_of(&type_asc.spec.spec_t, union)
{
return Err(self.elem_err(&l, &r, elem));
} // else(OK): e.g. [1, "a": Str or Int]
} else {
return Err(self.elem_err(&l, &r, elem));
}
}
// TODO: check if the type is compatible with the other type
(true, false) => {}
(false, true) => {}
(true, true) => {}
}
}
}
Ok(())
}
fn lower_array_with_length(
&mut self,
array: ast::ArrayWithLength,
@ -1731,23 +1743,24 @@ impl ASTLowerer {
if let Err(errs) = self.module.context.register_const(&body.block) {
self.errs.extend(errs);
}
match self.lower_block(body.block, None) {
let outer = self.module.context.outer.as_ref().unwrap();
let expect_body_t = sig
.t_spec
.as_ref()
.and_then(|t_spec| {
self.module
.context
.instantiate_var_sig_t(Some(&t_spec.t_spec), RegistrationMode::PreRegister)
.ok()
})
.or_else(|| {
sig.ident()
.and_then(|ident| outer.get_current_scope_var(&ident.name))
.map(|vi| vi.t.clone())
});
match self.lower_block(body.block, expect_body_t.as_ref()) {
Ok(block) => {
let found_body_t = block.ref_t();
let outer = self.module.context.outer.as_ref().unwrap();
let opt_expect_body_t = self
.module
.context
.instantiate_var_sig_t(
sig.t_spec.as_ref().map(|ts| &ts.t_spec),
RegistrationMode::PreRegister,
)
.ok()
.or_else(|| {
sig.ident()
.and_then(|ident| outer.get_current_scope_var(&ident.name))
.map(|vi| vi.t.clone())
});
let ident = match &sig.pat {
ast::VarPattern::Ident(ident) => ident.clone(),
ast::VarPattern::Discard(token) => {
@ -1755,7 +1768,7 @@ impl ASTLowerer {
}
_ => unreachable!(),
};
if let Some(expect_body_t) = opt_expect_body_t {
if let Some(expect_body_t) = expect_body_t {
// TODO: expect_body_t is smaller for constants
// TODO: 定数の場合、expect_body_tのほうが小さくなってしまう
if !sig.is_const() {

View file

@ -6,7 +6,7 @@ use std::sync::atomic::AtomicUsize;
use erg_common::shared::Forkable;
use erg_common::traits::{LimitedDisplay, StructuralEq};
use erg_common::Str;
use erg_common::{addr, Str};
use erg_common::{addr_eq, log};
use super::typaram::TyParam;
@ -556,7 +556,9 @@ impl Hash for Free<TyParam> {
if let Some(lev) = self.level() {
lev.hash(state);
}
if let Some(t) = self.get_type() {
if self.is_recursive() {
addr!(self).hash(state);
} else if let Some(t) = self.get_type() {
t.hash(state);
} else if self.is_linked() {
self.crack().hash(state);
@ -691,6 +693,39 @@ impl Free<TyParam> {
self.constraint().unwrap(),
)
}
pub fn is_recursive(&self) -> bool {
TyParam::FreeVar(self.clone()).is_recursive()
}
fn _do_avoiding_recursion<O, F: FnOnce() -> O>(
&self,
placeholder: Option<&TyParam>,
f: F,
) -> O {
let placeholder = placeholder.unwrap_or(&TyParam::Failure);
let is_recursive = self.is_recursive();
if is_recursive {
self.undoable_link(placeholder);
}
let res = f();
if is_recursive {
self.undo();
}
res
}
pub fn do_avoiding_recursion<O, F: FnOnce() -> O>(&self, f: F) -> O {
self._do_avoiding_recursion(None, f)
}
pub fn do_avoiding_recursion_with<O, F: FnOnce() -> O>(
&self,
placeholder: &TyParam,
f: F,
) -> O {
self._do_avoiding_recursion(Some(placeholder), f)
}
}
impl<T: StructuralEq + CanbeFree + Clone + Default + fmt::Debug + Send + Sync + 'static>

View file

@ -1939,6 +1939,14 @@ impl Type {
}
}
pub fn is_projection(&self) -> bool {
match self {
Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_projection(),
Self::Proj { .. } | Self::ProjCall { .. } => true,
_ => false,
}
}
pub fn is_intersection_type(&self) -> bool {
match self {
Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_intersection_type(),
@ -2286,10 +2294,9 @@ impl Type {
pub fn is_recursive(&self) -> bool {
match self {
Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_recursive(),
Self::FreeVar(fv) => fv
.get_subsup()
.map(|(sub, sup)| sub.contains_type(self) || sup.contains_type(self))
.unwrap_or(false),
Self::FreeVar(fv) => fv.get_subsup().map_or(false, |(sub, sup)| {
sub.contains_type(self) || sup.contains_type(self)
}),
Self::Record(rec) => rec.iter().any(|(_, t)| t.contains_type(self)),
Self::NamedTuple(rec) => rec.iter().any(|(_, t)| t.contains_type(self)),
Self::Poly { params, .. } => params.iter().any(|tp| tp.contains_type(self)),

View file

@ -1249,6 +1249,33 @@ impl TyParam {
}
}
pub fn is_recursive(&self) -> bool {
match self {
Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_recursive(),
Self::FreeVar(fv) => fv.get_type().map_or(false, |t| t.contains_tp(self)),
Self::Proj { obj, .. } => obj.contains_tp(self),
Self::ProjCall { obj, args, .. } => {
obj.contains_tp(self) || args.iter().any(|t| t.contains_tp(self))
}
Self::BinOp { lhs, rhs, .. } => lhs.contains_tp(self) || rhs.contains_tp(self),
Self::UnaryOp { val, .. } => val.contains_tp(self),
Self::App { args, .. } => args.iter().any(|t| t.contains_tp(self)),
Self::Lambda(lambda) => lambda.body.iter().any(|t| t.contains_tp(self)),
Self::Array(ts) | Self::Tuple(ts) => ts.iter().any(|t| t.contains_tp(self)),
Self::UnsizedArray(elem) => elem.contains_tp(self),
Self::Set(ts) => ts.iter().any(|t| t.contains_tp(self)),
Self::Record(rec) => rec.iter().any(|(_, t)| t.contains_tp(self)),
Self::Dict(ts) => ts
.iter()
.any(|(k, v)| k.contains_tp(self) || v.contains_tp(self)),
Self::DataClass { fields, .. } => fields.iter().any(|(_, t)| t.contains_tp(self)),
Self::Type(t) => t.contains_tp(self),
Self::Value(ValueObj::Type(t)) => t.typ().contains_tp(self),
Self::Erased(t) => t.contains_tp(self),
_ => false,
}
}
pub fn is_unbound_var(&self) -> bool {
match self {
Self::FreeVar(fv) => fv.is_unbound() || fv.crack().is_unbound_var(),

View file

@ -14,4 +14,5 @@ g2: (|T|(_: Structural { .a = (self: T) -> Obj }) -> NoneType) = f2
_, _ = f2, g2
Packages = [{ .name = Str; .version = Str }; _]
_: Packages = [{ .name = "a"; .version = "b" }]
_: Packages = [{ .name = "a"; .version = "1.0.0" }]
_: Packages = [{ .name = "a"; .version = "1.0.0" }, { .name = "b"; .version = "1.0.0" }]