fix equality/hash issue

lead to tests not being merged/shared
This commit is contained in:
Folkert 2020-03-20 16:23:56 +01:00
parent e062404a63
commit d0da300042
2 changed files with 229 additions and 142 deletions

View file

@ -1329,6 +1329,44 @@ mod test_gen {
);
}
/*
#[test]
fn pair_with_guard_pattern() {
assert_evals_to!(
indoc!(
r#"
when Pair 2 3 is
Pair 4 _ -> 9
Pair 3 _ -> 9
Pair a b -> a + b
"#
),
5,
i64
);
}
*/
#[test]
fn result_with_guard_pattern() {
// This test revealed an issue with hashing Test values
assert_evals_to!(
indoc!(
r#"
x : Result Int Int
x = Ok 2
when x is
Ok 3 -> 1
Ok _ -> 2
Err _ -> 3
"#
),
2,
i64
);
}
#[test]
fn maybe_is_just() {
assert_evals_to!(

View file

@ -41,7 +41,7 @@ pub enum DecisionTree<'a> {
},
}
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Test<'a> {
IsCtor {
tag_id: u8,
@ -59,6 +59,41 @@ pub enum Test<'a> {
num_alts: usize,
},
}
use std::hash::{Hash, Hasher};
impl<'a> Hash for Test<'a> {
fn hash<H: Hasher>(&self, state: &mut H) {
use Test::*;
match self {
IsCtor { tag_id, .. } => {
state.write_u8(0);
tag_id.hash(state);
// The point of this custom implementation is to not hash the tag arguments
}
IsInt(v) => {
state.write_u8(1);
v.hash(state);
}
IsFloat(v) => {
state.write_u8(2);
v.hash(state);
}
IsStr(v) => {
state.write_u8(3);
v.hash(state);
}
IsBit(v) => {
state.write_u8(4);
v.hash(state);
}
IsByte { tag_id, num_alts } => {
state.write_u8(5);
tag_id.hash(state);
num_alts.hash(state)
}
}
}
}
#[derive(Clone, Debug, PartialEq)]
pub enum Path {
@ -156,9 +191,9 @@ fn flatten<'a>(path_pattern: (Path, Pattern<'a>), path_patterns: &mut Vec<(Path,
let path = path_pattern.0;
// Theory: unbox doesn't have any value for us, because one-element tag unions
// don't store the tag anyway.
// if arguments.len() == 1 {
// path_patterns.push((Path::Unbox(Box::new(path)), path_pattern.1.clone()));
// } else {
if arguments.len() == 1 {
path_patterns.push((Path::Unbox(Box::new(path)), path_pattern.1.clone()));
} else {
for (index, (arg_pattern, _)) in arguments.iter().enumerate() {
flatten(
(
@ -172,6 +207,7 @@ fn flatten<'a>(path_pattern: (Path, Pattern<'a>), path_patterns: &mut Vec<(Path,
path_patterns,
);
}
}
} else {
path_patterns.push(path_pattern);
}
@ -239,9 +275,21 @@ fn tests_at_path<'a>(selected_path: &Path, branches: Vec<Branch<'a>>) -> Vec<Tes
.into_iter()
.filter_map(|b| test_at_path(selected_path, b));
// The rust HashMap also uses equality, here we really want to use the custom hash function
// defined on Test to determine whether a test is unique. So we have to do the hashing
// explicitly
use std::collections::hash_map::DefaultHasher;
for test in all_tests {
if !visited.contains(&test) {
visited.insert(test.clone());
let hash = {
let mut hasher = DefaultHasher::new();
test.hash(&mut hasher);
hasher.finish()
};
if !visited.contains(&hash) {
visited.insert(hash);
unique.push(test);
}
}
@ -338,7 +386,8 @@ fn to_relevant_branch<'a>(test: &Test<'a>, path: &Path, branch: Branch<'a>) -> O
mut start,
found_pattern: pattern,
end,
} => match pattern {
} => {
match pattern {
Identifier(_) | Underscore | Shadowed(_, _) | UnsupportedPattern(_) => Some(branch),
RecordDestructure(destructs, _) => match test {
@ -348,7 +397,6 @@ fn to_relevant_branch<'a>(test: &Test<'a>, path: &Path, branch: Branch<'a>) -> O
..
} => {
debug_assert!(test_name == &TagName::Global("#Record".into()));
let sub_positions =
destructs.into_iter().enumerate().map(|(index, destruct)| {
let pattern = if let Some(guard) = destruct.guard {
@ -380,6 +428,7 @@ fn to_relevant_branch<'a>(test: &Test<'a>, path: &Path, branch: Branch<'a>) -> O
AppliedTag {
tag_name,
arguments,
union,
..
} => {
match test {
@ -389,18 +438,15 @@ fn to_relevant_branch<'a>(test: &Test<'a>, path: &Path, branch: Branch<'a>) -> O
..
} if &tag_name == test_name => {
// Theory: Unbox doesn't have any value for us
// if arguments.len() == 1 && union.alternatives.len() == 1 {
// let arg = arguments.remove(0);
// {
// start.push((Path::Unbox(Box::new(path.clone())), arg));
// start.extend(end);
// }
// } else {
let sub_positions =
arguments
.into_iter()
.enumerate()
.map(|(index, (pattern, _))| {
if arguments.len() == 1 && union.alternatives.len() == 1 {
let arg = arguments[0].clone();
{
start.push((Path::Unbox(Box::new(path.clone())), arg.0));
start.extend(end);
}
} else {
let sub_positions = arguments.into_iter().enumerate().map(
|(index, (pattern, _))| {
(
Path::Index {
index: index as u64,
@ -409,9 +455,11 @@ fn to_relevant_branch<'a>(test: &Test<'a>, path: &Path, branch: Branch<'a>) -> O
},
pattern,
)
});
},
);
start.extend(sub_positions);
start.extend(end);
}
Some(Branch {
goal: branch.goal,
@ -478,7 +526,8 @@ fn to_relevant_branch<'a>(test: &Test<'a>, path: &Path, branch: Branch<'a>) -> O
_ => None,
},
},
}
}
}
}