chore: improve type inference system

This commit is contained in:
Shunsuke Shibayama 2023-10-18 16:46:06 +09:00
parent d0bae66450
commit 06898bd793
11 changed files with 332 additions and 43 deletions

View file

@ -654,11 +654,15 @@ impl Context {
ident: &Identifier,
input: &Input,
namespace: &Context,
expect: Option<&Type>,
) -> Triple<VarInfo, TyCheckError> {
// get_attr_info(?T, aaa) == None
// => ?T(<: Structural({ .aaa = ?U }))
if PYTHON_MODE && obj.var_info().is_some_and(|vi| vi.is_untyped_parameter()) {
let t = free_var(self.level, Constraint::new_type_of(Type));
let constraint = expect.map_or(Constraint::new_type_of(Type), |t| {
Constraint::new_subtype_of(t.clone())
});
let t = free_var(self.level, constraint);
if let Some(fv) = obj.ref_t().as_free() {
if fv.get_sub().is_some() {
let vis = self.instantiate_vis_modifier(&ident.vis).unwrap();
@ -1666,22 +1670,32 @@ impl Context {
subr.non_default_params.iter()
};
let non_default_params_len = non_default_params.len();
let mut nth = 1;
if pos_args.len() >= non_default_params_len {
let (non_default_args, var_args) = pos_args.split_at(non_default_params_len);
for (nd_arg, nd_param) in non_default_args.iter().zip(non_default_params) {
let mut args = non_default_args
.iter()
.zip(non_default_params)
.enumerate()
.collect::<Vec<_>>();
// TODO: remove `obj.local_name() != Some("__contains__")`
if obj.local_name() != Some("__contains__") && subr.has_unbound_var() {
args.sort_by(|(_, (l, _)), (_, (r, _))| {
l.expr.complexity().cmp(&r.expr.complexity())
});
}
for (i, (nd_arg, nd_param)) in args {
if let Err(mut es) = self.substitute_pos_arg(
&callee,
attr_name,
&nd_arg.expr,
nth,
i + 1,
nd_param,
&mut passed_params,
) {
errs.append(&mut es);
}
nth += 1;
}
let mut nth = 1 + non_default_params_len;
if let Some(var_param) = subr.var_params.as_ref() {
for var_arg in var_args.iter() {
if let Err(mut es) = self.substitute_var_arg(
@ -1736,7 +1750,9 @@ impl Context {
}
}
} else {
let mut nth = 1;
// pos_args.len() < non_default_params_len
// don't use `zip`
let mut params = non_default_params.chain(subr.default_params.iter());
for pos_arg in pos_args.iter() {
if let Err(mut es) = self.substitute_pos_arg(

View file

@ -774,6 +774,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
for (lps, rps) in lsub.typarams().iter().zip(rsub.typarams().iter()) {
self.sub_unify_tp(lps, rps, None, false).map_err(|errs| {
sup_fv.undo();
sub_fv.undo();
errs
})?;
}
@ -784,6 +785,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
for (lps, rps) in lsup.typarams().iter().zip(rsup.typarams().iter()) {
self.sub_unify_tp(lps, rps, None, false).map_err(|errs| {
sup_fv.undo();
sub_fv.undo();
errs
})?;
}

View file

@ -515,11 +515,13 @@ impl ASTLowerer {
fn fake_lower_lambda(&self, lambda: ast::Lambda) -> LowerResult<hir::Lambda> {
let params = self.fake_lower_params(lambda.sig.params)?;
let return_t_spec = lambda.sig.return_t_spec.map(|t_spec| t_spec.t_spec);
let body = self.fake_lower_block(lambda.body)?;
Ok(hir::Lambda::new(
lambda.id.0,
params,
lambda.op,
return_t_spec,
body,
Type::Failure,
))

View file

@ -1910,6 +1910,13 @@ impl Params {
)
}
/// Iterate over the parameter signatures held by this `Params`, in this order:
/// every non-default parameter, then the variadic parameter (when there is one),
/// then the signature part of every default parameter.
pub fn sigs(&self) -> impl Iterator<Item = &NonDefaultParamSignature> {
    let positional = self.non_defaults.iter();
    // `Option<&_>` is itself iterable and yields zero or one item.
    let variadic = self.var_params.as_deref();
    let defaulted = self.defaults.iter().map(|default| &default.sig);
    positional.chain(variadic).chain(defaulted)
}
pub fn deconstruct(self) -> RawParams {
(
self.non_defaults,
@ -2015,6 +2022,7 @@ impl SubrSignature {
pub struct Lambda {
pub params: Params,
pub op: Token,
pub return_t_spec: Option<TypeSpec>,
pub body: Block,
pub id: usize,
pub t: Type,
@ -2043,11 +2051,19 @@ impl_locational!(Lambda, params, body);
impl_t!(Lambda);
impl Lambda {
pub const fn new(id: usize, params: Params, op: Token, body: Block, t: Type) -> Self {
pub const fn new(
id: usize,
params: Params,
op: Token,
return_t_spec: Option<TypeSpec>,
body: Block,
t: Type,
) -> Self {
Self {
id,
params,
op,
return_t_spec,
body,
t,
}
@ -2778,6 +2794,82 @@ impl Expr {
pub fn type_asc_expr(self, t_spec: TypeSpecWithOp) -> Self {
Self::TypeAsc(self.type_asc(t_spec))
}
/// Return the complexity of the expression in terms of type inference.
/// For function calls, type inference is performed sequentially, starting with the least complex argument.
///
/// Scoring, as implemented below:
/// * literals and type ascriptions score 0, a bare identifier scores 1;
/// * attribute access, operators and calls score 1 plus the score of their operands;
/// * normal tuples/arrays/dicts/sets and records score the sum of their elements;
/// * a lambda scores 1, plus 1 for a missing return type spec, plus 1 per parameter
///   without a type spec, plus the score of every chunk in its body;
/// * every other expression kind scores a flat 5.
pub fn complexity(&self) -> usize {
    match self {
        Self::Literal(_) | Self::TypeAsc(_) => 0,
        Self::Accessor(Accessor::Ident(_)) => 1,
        Self::Accessor(Accessor::Attr(attr)) => 1 + attr.obj.complexity(),
        Self::Tuple(Tuple::Normal(tup)) => tup
            .elems
            .pos_args
            .iter()
            .map(|elem| elem.expr.complexity())
            .sum::<usize>(),
        Self::Array(Array::Normal(arr)) => arr
            .elems
            .pos_args
            .iter()
            .map(|elem| elem.expr.complexity())
            .sum::<usize>(),
        Self::Dict(Dict::Normal(dic)) => dic
            .kvs
            .iter()
            .map(|kv| kv.key.complexity() + kv.value.complexity())
            .sum::<usize>(),
        Self::Set(Set::Normal(set)) => set
            .elems
            .pos_args
            .iter()
            .map(|elem| elem.expr.complexity())
            .sum::<usize>(),
        // A record costs as much as all the chunks of all its attribute bodies.
        Self::Record(rec) => rec
            .attrs
            .iter()
            .map(|attr| {
                attr.body
                    .block
                    .iter()
                    .map(|chunk| chunk.complexity())
                    .sum::<usize>()
            })
            .sum::<usize>(),
        Self::BinOp(bin) => 1 + bin.lhs.complexity() + bin.rhs.complexity(),
        Self::UnaryOp(unary) => 1 + unary.expr.complexity(),
        Self::Call(call) => {
            let callee = call.obj.complexity();
            let args = &call.args;
            let positional = args
                .pos_args
                .iter()
                .map(|arg| arg.expr.complexity())
                .sum::<usize>();
            let variadic = args
                .var_args
                .as_ref()
                .map_or(0, |arg| arg.expr.complexity());
            let keyword = args
                .kw_args
                .iter()
                .map(|arg| arg.expr.complexity())
                .sum::<usize>();
            1 + callee + positional + variadic + keyword
        }
        Self::Lambda(lambda) => {
            // Each missing annotation (return type or parameter type) adds one point.
            let missing_return = usize::from(lambda.return_t_spec.is_none());
            let untyped_params = lambda
                .params
                .sigs()
                .filter(|sig| sig.raw.t_spec.is_none())
                .count();
            let body = lambda
                .body
                .iter()
                .map(|chunk| chunk.complexity())
                .sum::<usize>();
            1 + missing_return + untyped_params + body
        }
        _ => 5,
    }
}
}
/// Toplevel grammar unit

View file

@ -691,11 +691,12 @@ impl ASTLowerer {
.convert_type_to_dict_type(exp.clone())
.ok()
})
.map(|dict| dict.into_iter().collect::<Vec<_>>());
let expect_kvs = expect.transpose(dict.kvs.len());
for (kv, expect) in dict.kvs.into_iter().zip(expect_kvs) {
let key = self.lower_expr(kv.key, expect.as_ref().map(|(k, _)| k))?;
let value = self.lower_expr(kv.value, expect.as_ref().map(|(_, v)| v))?;
.and_then(|dict| (dict.len() == 1).then_some(dict));
for kv in dict.kvs.into_iter() {
let expect_key = expect.as_ref().map(|dict| dict.keys().next().unwrap());
let expect_value = expect.as_ref().map(|dict| dict.values().next().unwrap());
let key = self.lower_expr(kv.key, expect_key)?;
let value = self.lower_expr(kv.value, expect_value)?;
if let Some(popped_val_t) = union.insert(key.t(), value.t()) {
if PYTHON_MODE {
let val_t = union.get_mut(key.ref_t()).unwrap();
@ -792,6 +793,7 @@ impl ASTLowerer {
&attr.ident,
&self.cfg.input,
&self.module.context,
expect,
) {
Triple::Ok(vi) => {
self.inc_ref(attr.ident.inspect(), &vi, &attr.ident.name);
@ -822,6 +824,15 @@ impl ASTLowerer {
VarInfo::ILLEGAL
}
};
if let Some(expect) = expect {
if let Err(_errs) =
self.module
.context
.sub_unify(&vi.t, expect, &attr.ident.loc(), None)
{
// self.errs.extend(errs);
}
}
let ident = hir::Identifier::new(attr.ident, None, vi);
let acc = hir::Accessor::Attr(hir::Attribute::new(obj, ident));
Ok(acc)
@ -1076,20 +1087,38 @@ impl ASTLowerer {
.map_or(vec![None; pos_args.len()], |params| {
params.take(pos_args.len()).collect()
});
for (nth, (arg, param)) in pos_args.into_iter().zip(pos_params).enumerate() {
let mut pos_args = pos_args
.into_iter()
.zip(pos_params)
.enumerate()
.collect::<Vec<_>>();
// `if` may perform assert casting, so don't sort args
if self
.module
.context
.control_kind()
.map_or(true, |kind| !kind.is_if())
&& expect.is_some_and(|subr| subr.has_unbound_var())
{
pos_args
.sort_by(|(_, (l, _)), (_, (r, _))| l.expr.complexity().cmp(&r.expr.complexity()));
}
let mut hir_pos_args =
vec![hir::PosArg::new(hir::Expr::Dummy(hir::Dummy::empty())); pos_args.len()];
for (nth, (arg, param)) in pos_args {
match self.lower_expr(arg.expr, param) {
Ok(expr) => {
if let Some(kind) = self.module.context.control_kind() {
self.push_guard(nth, kind, expr.ref_t());
}
hir_args.pos_args.push(hir::PosArg::new(expr))
hir_pos_args[nth] = hir::PosArg::new(expr);
}
Err(es) => {
errs.extend(es);
hir_args.push_pos(hir::PosArg::new(hir::Expr::Dummy(hir::Dummy::empty())));
}
}
}
hir_args.pos_args = hir_pos_args;
// TODO: expect var_args
if let Some(var_args) = var_args {
match self.lower_expr(var_args.expr, None) {
@ -1121,6 +1150,12 @@ impl ASTLowerer {
hir_args
}
/// ```erg
/// x: Int or NoneType
/// if x != None:
/// do: ... # x: Int (x != None)
/// do: ... # x: NoneType (complement(x != None))
/// ```
fn push_guard(&mut self, nth: usize, kind: ControlKind, t: &Type) {
match t {
Type::Guard(guard) => match nth {
@ -1196,6 +1231,17 @@ impl ASTLowerer {
.as_ref()
.ok()
.and_then(|vi| <&SubrType>::try_from(&vi.t).ok());
if let Some((subr_return_t, expect)) =
expect_subr.map(|subr| subr.return_t.as_ref()).zip(expect)
{
if let Err(_errs) = self
.module
.context
.sub_unify(subr_return_t, expect, &(), None)
{
// self.errs.extend(errs);
}
}
let hir_args = self.lower_args(call.args, expect_subr, &mut errs);
let mut vi = match self.module.context.get_call_t(
&obj,
@ -1245,11 +1291,15 @@ impl ASTLowerer {
None
};
let mut call = hir::Call::new(obj, attr_name, hir_args);
/*if let Some((found, expect)) = call.signature_t().and_then(|sig| sig.return_t()).zip(expect) {
if let Err(errs) = self.module.context.sub_unify(found, expect, &call, None) {
self.errs.extend(errs);
if let Some((found, expect)) = call
.signature_t()
.and_then(|sig| sig.return_t())
.zip(expect)
{
if let Err(_errs) = self.module.context.sub_unify(found, expect, &call, None) {
// self.errs.extend(errs);
}
}*/
}
if pushed {
self.module.context.higher_order_caller.pop();
}
@ -1464,6 +1514,7 @@ impl ASTLowerer {
expect: Option<&Type>,
) -> LowerResult<hir::Lambda> {
let expect = expect.and_then(|t| <&SubrType>::try_from(t).ok());
let return_t = expect.map(|subr| subr.return_t.as_ref());
log!(info "entered {}({lambda})", fn_name!());
let in_statement = PYTHON_MODE
&& self
@ -1513,7 +1564,7 @@ impl ASTLowerer {
if let Err(errs) = self.module.context.register_const(&lambda.body) {
self.errs.extend(errs);
}
let body = self.lower_block(lambda.body, None).map_err(|errs| {
let body = self.lower_block(lambda.body, return_t).map_err(|errs| {
if !in_statement {
self.pop_append_errs();
}
@ -1646,7 +1697,15 @@ impl ASTLowerer {
} else {
ty
};
Ok(hir::Lambda::new(id, params, lambda.op, body, t))
let return_t_spec = lambda.sig.return_t_spec.map(|t_spec| t_spec.t_spec);
Ok(hir::Lambda::new(
id,
params,
lambda.op,
return_t_spec,
body,
t,
))
}
fn lower_def(&mut self, def: ast::Def) -> LowerResult<hir::Def> {
@ -1851,7 +1910,11 @@ impl ASTLowerer {
if let Err(errs) = self.module.context.register_const(&body.block) {
self.errs.extend(errs);
}
match self.lower_block(body.block, None) {
let return_t = subr_t
.return_t
.has_no_unbound_var()
.then_some(subr_t.return_t.as_ref());
match self.lower_block(body.block, return_t) {
Ok(block) => {
let found_body_t = self.module.context.squash_tyvar(block.t());
let vi = match self.module.context.outer.as_mut().unwrap().assign_subr(

View file

@ -391,8 +391,7 @@ impl SubrType {
|| self
.var_params
.as_ref()
.map(|pt| pt.typ().contains_tvar(target))
.unwrap_or(false)
.map_or(false, |pt| pt.typ().contains_tvar(target))
|| self
.default_params
.iter()
@ -407,8 +406,7 @@ impl SubrType {
|| self
.var_params
.as_ref()
.map(|pt| pt.typ().contains_type(target))
.unwrap_or(false)
.map_or(false, |pt| pt.typ().contains_type(target))
|| self
.default_params
.iter()
@ -423,8 +421,7 @@ impl SubrType {
|| self
.var_params
.as_ref()
.map(|pt| pt.typ().contains_tp(target))
.unwrap_or(false)
.map_or(false, |pt| pt.typ().contains_tp(target))
|| self
.default_params
.iter()
@ -475,12 +472,26 @@ impl SubrType {
|| self
.var_params
.as_ref()
.map(|pt| pt.typ().has_qvar())
.unwrap_or(false)
.map_or(false, |pt| pt.typ().has_qvar())
|| self.default_params.iter().any(|pt| pt.typ().has_qvar())
|| self.return_t.has_qvar()
}
/// Report whether any type appearing in this subroutine type has an unbound variable.
///
/// The parameter types are checked first (non-default, then the variadic one if
/// present, then default), and the return type last; the first hit returns `true`.
pub fn has_unbound_var(&self) -> bool {
    if self
        .non_default_params
        .iter()
        .any(|param| param.typ().has_unbound_var())
    {
        return true;
    }
    if let Some(variadic) = self.var_params.as_ref() {
        if variadic.typ().has_unbound_var() {
            return true;
        }
    }
    if self
        .default_params
        .iter()
        .any(|param| param.typ().has_unbound_var())
    {
        return true;
    }
    self.return_t.has_unbound_var()
}
pub fn has_undoable_linked_var(&self) -> bool {
self.non_default_params
.iter()
@ -488,8 +499,7 @@ impl SubrType {
|| self
.var_params
.as_ref()
.map(|pt| pt.typ().has_undoable_linked_var())
.unwrap_or(false)
.map_or(false, |pt| pt.typ().has_undoable_linked_var())
|| self
.default_params
.iter()
@ -519,8 +529,7 @@ impl SubrType {
pub fn self_t(&self) -> Option<&Type> {
self.non_default_params.first().and_then(|p| {
if p.name()
.map(|n| &n[..] == "self" || &n[..] == "Self")
.unwrap_or(false)
.map_or(false, |n| &n[..] == "self" || &n[..] == "Self")
{
Some(p.typ())
} else {
@ -532,8 +541,7 @@ impl SubrType {
pub fn mut_self_t(&mut self) -> Option<&mut Type> {
self.non_default_params.first_mut().and_then(|p| {
if p.name()
.map(|n| &n[..] == "self" || &n[..] == "Self")
.unwrap_or(false)
.map_or(false, |n| &n[..] == "self" || &n[..] == "Self")
{
Some(p.typ_mut())
} else {

View file

@ -4239,6 +4239,13 @@ impl Params {
self.non_defaults.len() + self.defaults.len()
}
/// Iterate over the parameter signatures held by this `Params`, in this order:
/// every non-default parameter, then the variadic parameter (when there is one),
/// then the signature part of every default parameter.
pub fn sigs(&self) -> impl Iterator<Item = &NonDefaultParamSignature> {
    let positional = self.non_defaults.iter();
    // `Option<&_>` is itself iterable and yields zero or one item.
    let variadic = self.var_params.as_deref();
    let defaulted = self.defaults.iter().map(|default| &default.sig);
    positional.chain(variadic).chain(defaulted)
}
#[inline]
pub fn is_empty(&self) -> bool {
self.len() == 0
@ -5040,6 +5047,83 @@ impl Expr {
pub fn unary_op(self, op: Token) -> UnaryOp {
UnaryOp::new(op, self)
}
/// Return the complexity of the expression in terms of type inference.
/// For function calls, type inference is performed sequentially, starting with the least complex argument.
///
/// Scoring, as implemented below:
/// * literals and type ascriptions score 0, a bare identifier scores 1;
/// * attribute access, operators and calls score 1 plus the score of their operands;
/// * normal tuples/arrays/dicts/sets/records score the sum of their elements;
/// * a lambda scores 1, plus 1 for a missing return type spec, plus 1 per parameter
///   without a type spec, plus the score of every chunk in its body;
/// * every other expression kind scores a flat 5.
pub fn complexity(&self) -> usize {
    match self {
        Self::Literal(_) | Self::TypeAscription(_) => 0,
        Self::Accessor(Accessor::Ident(_)) => 1,
        Self::Accessor(Accessor::Attr(attr)) => 1 + attr.obj.complexity(),
        Self::Tuple(Tuple::Normal(tup)) => tup
            .elems
            .pos_args
            .iter()
            .map(|elem| elem.expr.complexity())
            .sum::<usize>(),
        Self::Array(Array::Normal(arr)) => arr
            .elems
            .pos_args
            .iter()
            .map(|elem| elem.expr.complexity())
            .sum::<usize>(),
        Self::Dict(Dict::Normal(dic)) => dic
            .kvs
            .iter()
            .map(|kv| kv.key.complexity() + kv.value.complexity())
            .sum::<usize>(),
        Self::Set(Set::Normal(set)) => set
            .elems
            .pos_args
            .iter()
            .map(|elem| elem.expr.complexity())
            .sum::<usize>(),
        // A record costs as much as all the chunks of all its attribute bodies.
        Self::Record(Record::Normal(rec)) => rec
            .attrs
            .iter()
            .map(|attr| {
                attr.body
                    .block
                    .iter()
                    .map(|chunk| chunk.complexity())
                    .sum::<usize>()
            })
            .sum::<usize>(),
        Self::BinOp(bin) => 1 + bin.args[0].complexity() + bin.args[1].complexity(),
        Self::UnaryOp(unary) => 1 + unary.args[0].complexity(),
        Self::Call(call) => {
            let callee = call.obj.complexity();
            let args = &call.args;
            let positional = args
                .pos_args
                .iter()
                .map(|arg| arg.expr.complexity())
                .sum::<usize>();
            let variadic = args
                .var_args
                .as_ref()
                .map_or(0, |arg| arg.expr.complexity());
            let keyword = args
                .kw_args
                .iter()
                .map(|arg| arg.expr.complexity())
                .sum::<usize>();
            1 + callee + positional + variadic + keyword
        }
        Self::Lambda(lambda) => {
            // Each missing annotation (return type or parameter type) adds one point.
            let missing_return = usize::from(lambda.sig.return_t_spec.is_none());
            let untyped_params = lambda
                .sig
                .params
                .sigs()
                .filter(|sig| sig.t_spec.is_none())
                .count();
            let body = lambda
                .body
                .iter()
                .map(|chunk| chunk.complexity())
                .sum::<usize>();
            1 + missing_return + untyped_params + body
        }
        _ => 5,
    }
}
}
#[derive(Clone, Debug, PartialEq, Eq, Hash)]

View file

@ -7,6 +7,7 @@ use erg_common::error::{Location, MultiErrorDisplay};
use erg_common::io::{DummyStdin, Input, Output};
use erg_common::python_util::PythonVersion;
use erg_common::spawn::exec_new_thread;
use erg_common::style::remove_style;
use erg_common::style::{colors::DEBUG_MAIN, RESET};
use erg_common::traits::{ExitStatus, Runnable, Stream};
@ -192,9 +193,9 @@ pub(crate) fn expect_failure(
}
}
pub(crate) fn expect_error_location(
pub(crate) fn expect_error_location_and_msg(
file_path: &'static str,
locs: Vec<Location>,
locs: Vec<(Location, Option<&str>)>,
) -> Result<(), ()> {
match exec_compiler(file_path) {
Ok(_) => {
@ -202,7 +203,7 @@ pub(crate) fn expect_error_location(
Err(())
}
Err(errs) => {
for (err, loc) in errs.into_iter().zip(locs) {
for (err, (loc, msg)) in errs.into_iter().zip(locs) {
if err.core.loc != loc {
println!(
"err[{file_path}]: error location should be {loc}, but got {}",
@ -210,6 +211,13 @@ pub(crate) fn expect_error_location(
);
return Err(());
}
if msg.is_some_and(|m| remove_style(&err.core.main_message) != m) {
println!(
"err[{file_path}]: error message should be {:?}, but got {:?}",
msg, err.core.main_message
);
return Err(());
}
}
Ok(())
}

View file

@ -8,3 +8,6 @@ for! zip([1+1], ["a"+"b"]), ((i, s),) => # i: Nat, s: Str
for! {"a": 1}, s =>
print! s + 1 # ERR
arr as Array(Int) = [1, 2]
_ = all map((i) -> i.method(), arr) # ERR

View file

@ -23,3 +23,9 @@ assert dic in {Str: Str}
.f dic: {Str: Str or Array(Str)} =
assert dic["key"] in Str # Required to pass the check on the next line
assert dic["key"] in {"a", "b", "c"}
assert dic["key2"] in Array(Str)
b as Bytes or NoneType = bytes "aaa", "utf-8"
_ = if b != None:
do b.decode("utf-8")
do ""

View file

@ -1,6 +1,7 @@
mod common;
use common::{
expect_compile_success, expect_end_with, expect_error_location, expect_failure, expect_success,
expect_compile_success, expect_end_with, expect_error_location_and_msg, expect_failure,
expect_success,
};
use erg_common::error::Location;
use erg_common::python_util::{env_python_version, module_exists, opt_which_python};
@ -599,12 +600,16 @@ fn exec_visibility() -> Result<(), ()> {
#[test]
fn exec_err_loc() -> Result<(), ()> {
expect_error_location(
expect_error_location_and_msg(
"tests/should_err/err_loc.er",
vec![
Location::range(2, 11, 2, 16),
Location::range(7, 15, 7, 18),
Location::range(10, 11, 10, 16),
(Location::range(2, 11, 2, 16), None),
(Location::range(7, 11, 7, 12), None),
(
Location::range(13, 21, 13, 27),
Some("Int object has no attribute method"),
),
(Location::range(10, 11, 10, 16), None),
],
)
}