fix: {default, variable} parameter bugs

This commit is contained in:
Shunsuke Shibayama 2024-04-24 15:24:05 +09:00
parent dc9cbf98a7
commit 3d7283cb01
8 changed files with 131 additions and 16 deletions

View file

@ -1,4 +1,5 @@
//! provides type-comparison
use std::iter::repeat;
use std::option::Option; // conflicting to Type::Option
use erg_common::consts::DEBUG_MODE;
@ -357,6 +358,7 @@ impl Context {
match (lhs, rhs) {
// Proc :> Func if params are compatible
// * default params can be omitted (e.g. (Int, x := Int) -> Int <: (Int) -> Int)
// * and default params can be non-default (e.g. (Int, x := Int) -> Int <: (Int, Int) -> Int)
(Subr(ls), Subr(rs)) if ls.kind == rs.kind || ls.kind.is_proc() => {
let default_check = || {
for lpt in ls.default_params.iter() {
@ -376,8 +378,12 @@ impl Context {
};
// () -> Never <: () -> Int <: () -> Object
// (Object) -> Int <: (Int) -> Int <: (Never) -> Int
// (Int, n := Int) -> Int <: (Int, Int) -> Int
// (Int, n := Int, m := Int) -> Int <: (Int, Int) -> Int
// (Int, n := Int) -> Int <!: (Int, Int, Int) -> Int
// (*Int) -> Int <: (Int, Int) -> Int
let same_params_len = ls.non_default_params.len() == rs.non_default_params.len()
let len_judge = ls.non_default_params.len()
<= rs.non_default_params.len() + rs.default_params.len()
|| rs.var_params.is_some();
// && ls.default_params.len() <= rs.default_params.len();
let rhs_ret = rs
@ -385,18 +391,24 @@ impl Context {
.clone()
.replace_params(rs.param_names(), ls.param_names());
let return_t_judge = self.supertype_of(&ls.return_t, &rhs_ret); // covariant
let non_defaults_judge = ls
.non_default_params
.iter()
.zip(rs.non_default_params.iter())
.all(|(l, r)| self.subtype_of(l.typ(), r.typ()));
let non_defaults_judge = if let Some(r_var) = rs.var_params.as_deref() {
ls.non_default_params
.iter()
.zip(repeat(r_var))
.all(|(l, r)| self.subtype_of(l.typ(), r.typ()))
} else {
ls.non_default_params
.iter()
.zip(rs.non_default_params.iter().chain(rs.default_params.iter()))
.all(|(l, r)| self.subtype_of(l.typ(), r.typ()))
};
let var_params_judge = ls
.var_params
.as_ref()
.zip(rs.var_params.as_ref())
.map(|(l, r)| self.subtype_of(l.typ(), r.typ()))
.unwrap_or(true);
same_params_len
len_judge
&& return_t_judge
&& non_defaults_judge
&& var_params_judge

View file

@ -1,4 +1,5 @@
//! provides type variable related operations
use std::iter::repeat;
use std::mem;
use std::option::Option;
@ -1252,14 +1253,28 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
}
}
(Subr(sub_subr), Subr(sup_subr)) => {
sub_subr
.non_default_params
.iter()
.zip(sup_subr.non_default_params.iter())
.try_for_each(|(sub, sup)| {
// contravariant
self.sub_unify(sup.typ(), sub.typ())
})?;
// (Int, *Int) -> ... <: (T, U, V) -> ...
if let Some(sub_var) = sub_subr.var_params.as_deref() {
sub_subr
.non_default_params
.iter()
.chain(repeat(sub_var))
.zip(sup_subr.non_default_params.iter())
.try_for_each(|(sub, sup)| {
// contravariant
self.sub_unify(sup.typ(), sub.typ())
})?;
} else {
sub_subr
.non_default_params
.iter()
.chain(sub_subr.default_params.iter())
.zip(sup_subr.non_default_params.iter())
.try_for_each(|(sub, sup)| {
// contravariant
self.sub_unify(sup.typ(), sub.typ())
})?;
}
sub_subr
.var_params
.iter()

View file

@ -0,0 +1,62 @@
.Differ: ClassType
.HTMLDiff: ClassType
.SequenceMatcher: ClassType
.context_diff: (
a: Sequence(Str),
b: Sequence(Str),
fromfile := Str,
tofile := Str,
fromfiledate := Str,
tofiledate := Str,
n := Nat,
lineterm := Str,
) -> Iterator Str
.get_close_matches: (
word: Str,
possibilities: Sequence(Str),
n := Nat,
cutoff := Float,
) -> [Str; _]
.ndiff: (
a: Sequence(Str),
b: Sequence(Str),
linejunk := Str -> Bool,
charjunk := Str -> Bool,
) -> Iterator Str
.restore: (sequence: Iterator(Str), which: Nat) -> Str
.unified_diff: (
a: Sequence(Str),
b: Sequence(Str),
fromfile := Str,
tofile := Str,
fromfiledate := Str,
tofiledate := Str,
n := Nat,
lineterm := Str,
) -> Iterator Str
.diff_bytes: (
dfunc: ((
a: Sequence(Str),
b: Sequence(Str),
fromfile: Str,
tofile: Str,
fromfiledate: Str,
tofiledate: Str,
n: Nat,
lineterm: Str,
) -> Iterator Str),
a: Sequence(Bytes),
b: Sequence(Bytes),
fromfile := Bytes,
tofile := Bytes,
fromfiledate := Bytes,
tofiledate := Bytes,
n := Nat,
lineterm := Bytes,
) -> Iterator Bytes
.IS_LINE_JUNK: Str -> Bool
.IS_CHARACTER_JUNK: Str -> Bool