fix: {default, variable} parameter bugs

This commit is contained in:
Shunsuke Shibayama 2024-04-24 15:24:05 +09:00
parent dc9cbf98a7
commit 3d7283cb01
8 changed files with 131 additions and 16 deletions

View file

@ -1,4 +1,5 @@
//! provides type-comparison
use std::iter::repeat;
use std::option::Option; // conflicting to Type::Option
use erg_common::consts::DEBUG_MODE;
@ -357,6 +358,7 @@ impl Context {
match (lhs, rhs) {
// Proc :> Func if params are compatible
// * default params can be omitted (e.g. (Int, x := Int) -> Int <: (Int) -> Int)
// * and default params can be non-default (e.g. (Int, x := Int) -> Int <: (Int, Int) -> Int)
(Subr(ls), Subr(rs)) if ls.kind == rs.kind || ls.kind.is_proc() => {
let default_check = || {
for lpt in ls.default_params.iter() {
@ -376,8 +378,12 @@ impl Context {
};
// () -> Never <: () -> Int <: () -> Object
// (Object) -> Int <: (Int) -> Int <: (Never) -> Int
// (Int, n := Int) -> Int <: (Int, Int) -> Int
// (Int, n := Int, m := Int) -> Int <: (Int, Int) -> Int
// (Int, n := Int) -> Int <!: (Int, Int, Int) -> Int
// (*Int) -> Int <: (Int, Int) -> Int
let same_params_len = ls.non_default_params.len() == rs.non_default_params.len()
let len_judge = ls.non_default_params.len()
<= rs.non_default_params.len() + rs.default_params.len()
|| rs.var_params.is_some();
// && ls.default_params.len() <= rs.default_params.len();
let rhs_ret = rs
@ -385,18 +391,24 @@ impl Context {
.clone()
.replace_params(rs.param_names(), ls.param_names());
let return_t_judge = self.supertype_of(&ls.return_t, &rhs_ret); // covariant
let non_defaults_judge = ls
.non_default_params
.iter()
.zip(rs.non_default_params.iter())
.all(|(l, r)| self.subtype_of(l.typ(), r.typ()));
let non_defaults_judge = if let Some(r_var) = rs.var_params.as_deref() {
ls.non_default_params
.iter()
.zip(repeat(r_var))
.all(|(l, r)| self.subtype_of(l.typ(), r.typ()))
} else {
ls.non_default_params
.iter()
.zip(rs.non_default_params.iter().chain(rs.default_params.iter()))
.all(|(l, r)| self.subtype_of(l.typ(), r.typ()))
};
let var_params_judge = ls
.var_params
.as_ref()
.zip(rs.var_params.as_ref())
.map(|(l, r)| self.subtype_of(l.typ(), r.typ()))
.unwrap_or(true);
same_params_len
len_judge
&& return_t_judge
&& non_defaults_judge
&& var_params_judge

View file

@ -1,4 +1,5 @@
//! provides type variable related operations
use std::iter::repeat;
use std::mem;
use std::option::Option;
@ -1252,14 +1253,28 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
}
}
(Subr(sub_subr), Subr(sup_subr)) => {
sub_subr
.non_default_params
.iter()
.zip(sup_subr.non_default_params.iter())
.try_for_each(|(sub, sup)| {
// contravariant
self.sub_unify(sup.typ(), sub.typ())
})?;
// (Int, *Int) -> ... <: (T, U, V) -> ...
if let Some(sub_var) = sub_subr.var_params.as_deref() {
sub_subr
.non_default_params
.iter()
.chain(repeat(sub_var))
.zip(sup_subr.non_default_params.iter())
.try_for_each(|(sub, sup)| {
// contravariant
self.sub_unify(sup.typ(), sub.typ())
})?;
} else {
sub_subr
.non_default_params
.iter()
.chain(sub_subr.default_params.iter())
.zip(sup_subr.non_default_params.iter())
.try_for_each(|(sub, sup)| {
// contravariant
self.sub_unify(sup.typ(), sub.typ())
})?;
}
sub_subr
.var_params
.iter()

View file

@ -0,0 +1,62 @@
.Differ: ClassType
.HTMLDiff: ClassType
.SequenceMatcher: ClassType
.context_diff: (
a: Sequence(Str),
b: Sequence(Str),
fromfile := Str,
tofile := Str,
fromfiledate := Str,
tofiledate := Str,
n := Nat,
lineterm := Str,
) -> Iterator Str
.get_close_matches: (
word: Str,
possibilities: Sequence(Str),
n := Nat,
cutoff := Float,
) -> [Str; _]
.ndiff: (
a: Sequence(Str),
b: Sequence(Str),
linejunk := Str -> Bool,
charjunk := Str -> Bool,
) -> Iterator Str
.restore: (sequence: Iterator(Str), which: Nat) -> Str
.unified_diff: (
a: Sequence(Str),
b: Sequence(Str),
fromfile := Str,
tofile := Str,
fromfiledate := Str,
tofiledate := Str,
n := Nat,
lineterm := Str,
) -> Iterator Str
.diff_bytes: (
dfunc: ((
a: Sequence(Str),
b: Sequence(Str),
fromfile: Str,
tofile: Str,
fromfiledate: Str,
tofiledate: Str,
n: Nat,
lineterm: Str,
) -> Iterator Str),
a: Sequence(Bytes),
b: Sequence(Bytes),
fromfile := Bytes,
tofile := Bytes,
fromfiledate := Bytes,
tofiledate := Bytes,
n := Nat,
lineterm := Bytes,
) -> Iterator Bytes
.IS_LINE_JUNK: Str -> Bool
.IS_CHARACTER_JUNK: Str -> Bool

View file

@ -0,0 +1,6 @@
f x: Int, y: Int := 1, z: Nat := 2 = x + y + z
_: (Int, Str) -> Int = f # ERR
_: (Int, Int, Int) -> Int = f # ERR (contravariant)
id_or_int x := 1 = x
_: Int -> Str = id_or_int # ERR

View file

@ -6,3 +6,5 @@ assert first(1, 2, 3) == "b" # ERR
f = (*_: Int) -> None
f "a", 1, 2
_: (Int, Str) -> NoneType = f # ERR

View file

@ -9,3 +9,12 @@ i = id_or_int()
s = id_or_int "a"
assert i + 1 + 1 == 3
assert s + "b" == "ab"
_: (Int, y := Int, z := Nat) -> Int = f
_: (Int, y := Int) -> Int = f
_: Int -> Int = f
_: (Int, Int) -> Int = f
_: (Int, Int, Nat) -> Int = f
_: (Int, Int, _: {1}) -> Int = f
_: Int -> Int = id_or_int
_: Str -> Str = id_or_int

View file

@ -10,3 +10,7 @@ assert sum_(0, 1, 2, 3) == 6
f = (*_: Int) -> None
f(1, 2, 3)
_: Int -> NoneType = f
_: (Int, Int) -> NoneType = f
_: (Int, Nat) -> NoneType = f

View file

@ -522,6 +522,11 @@ fn exec_collection_err() -> Result<(), ()> {
expect_failure("tests/should_err/collection.er", 0, 5)
}
#[test]
fn exec_default_param_err() -> Result<(), ()> {
expect_failure("tests/should_err/default_param.er", 0, 3)
}
#[test]
fn exec_dependent_err() -> Result<(), ()> {
expect_failure("tests/should_err/dependent.er", 0, 5)
@ -696,7 +701,7 @@ fn exec_refinement_class_err() -> Result<(), ()> {
#[test]
fn exec_var_args_err() -> Result<(), ()> {
expect_failure("tests/should_err/var_args.er", 0, 3)
expect_failure("tests/should_err/var_args.er", 0, 4)
}
#[test]