diff --git a/compiler/gen/tests/test_gen.rs b/compiler/gen/tests/test_gen.rs index d4671c63fd..44de4820ba 100644 --- a/compiler/gen/tests/test_gen.rs +++ b/compiler/gen/tests/test_gen.rs @@ -1329,6 +1329,44 @@ mod test_gen { ); } + /* + #[test] + fn pair_with_guard_pattern() { + assert_evals_to!( + indoc!( + r#" + when Pair 2 3 is + Pair 4 _ -> 9 + Pair 3 _ -> 9 + Pair a b -> a + b + "# + ), + 5, + i64 + ); + } + */ + + #[test] + fn result_with_guard_pattern() { + // This test revealed an issue with hashing Test values + assert_evals_to!( + indoc!( + r#" + x : Result Int Int + x = Ok 2 + + when x is + Ok 3 -> 1 + Ok _ -> 2 + Err _ -> 3 + "# + ), + 2, + i64 + ); + } + #[test] fn maybe_is_just() { assert_evals_to!( diff --git a/compiler/mono/src/decision_tree.rs b/compiler/mono/src/decision_tree.rs index 7ccb53a59b..b3deaa4e6e 100644 --- a/compiler/mono/src/decision_tree.rs +++ b/compiler/mono/src/decision_tree.rs @@ -41,7 +41,7 @@ pub enum DecisionTree<'a> { }, } -#[derive(Clone, Debug, PartialEq, Eq, Hash)] +#[derive(Clone, Debug, PartialEq, Eq)] pub enum Test<'a> { IsCtor { tag_id: u8, @@ -59,6 +59,41 @@ pub enum Test<'a> { num_alts: usize, }, } +use std::hash::{Hash, Hasher}; +impl<'a> Hash for Test<'a> { + fn hash(&self, state: &mut H) { + use Test::*; + + match self { + IsCtor { tag_id, .. } => { + state.write_u8(0); + tag_id.hash(state); + // The point of this custom implementation is to not hash the tag arguments + } + IsInt(v) => { + state.write_u8(1); + v.hash(state); + } + IsFloat(v) => { + state.write_u8(2); + v.hash(state); + } + IsStr(v) => { + state.write_u8(3); + v.hash(state); + } + IsBit(v) => { + state.write_u8(4); + v.hash(state); + } + IsByte { tag_id, num_alts } => { + state.write_u8(5); + tag_id.hash(state); + num_alts.hash(state) + } + } + } +} #[derive(Clone, Debug, PartialEq)] pub enum Path { @@ -156,21 +191,22 @@ fn flatten<'a>(path_pattern: (Path, Pattern<'a>), path_patterns: &mut Vec<(Path, let path = path_pattern.0; // Theory: unbox doesn't have any value for us, because one-element tag unions // don't store the tag anyway. - // if arguments.len() == 1 { - // path_patterns.push((Path::Unbox(Box::new(path)), path_pattern.1.clone())); - // } else { - for (index, (arg_pattern, _)) in arguments.iter().enumerate() { - flatten( - ( - Path::Index { - index: index as u64, - tag_id: *tag_id, - path: Box::new(path.clone()), - }, - arg_pattern.clone(), - ), - path_patterns, - ); + if arguments.len() == 1 { + path_patterns.push((Path::Unbox(Box::new(path)), path_pattern.1.clone())); + } else { + for (index, (arg_pattern, _)) in arguments.iter().enumerate() { + flatten( + ( + Path::Index { + index: index as u64, + tag_id: *tag_id, + path: Box::new(path.clone()), + }, + arg_pattern.clone(), + ), + path_patterns, + ); + } } } else { path_patterns.push(path_pattern); @@ -239,9 +275,21 @@ fn tests_at_path<'a>(selected_path: &Path, branches: Vec>) -> Vec(test: &Test<'a>, path: &Path, branch: Branch<'a>) -> O mut start, found_pattern: pattern, end, - } => match pattern { - Identifier(_) | Underscore | Shadowed(_, _) | UnsupportedPattern(_) => Some(branch), + } => { + match pattern { + Identifier(_) | Underscore | Shadowed(_, _) | UnsupportedPattern(_) => Some(branch), - RecordDestructure(destructs, _) => match test { - IsCtor { - tag_name: test_name, - tag_id, - .. - } => { - debug_assert!(test_name == &TagName::Global("#Record".into())); - - let sub_positions = - destructs.into_iter().enumerate().map(|(index, destruct)| { - let pattern = if let Some(guard) = destruct.guard { - guard.clone() - } else { - Pattern::Underscore - }; - - ( - Path::Index { - index: index as u64, - tag_id: *tag_id, - path: Box::new(path.clone()), - }, - pattern, - ) - }); - start.extend(sub_positions); - start.extend(end); - - Some(Branch { - goal: branch.goal, - patterns: start, - }) - } - _ => None, - }, - - AppliedTag { - tag_name, - arguments, - .. - } => { - match test { + RecordDestructure(destructs, _) => match test { IsCtor { tag_name: test_name, tag_id, .. - } if &tag_name == test_name => { - // Theory: Unbox doesn't have any value for us - // if arguments.len() == 1 && union.alternatives.len() == 1 { - // let arg = arguments.remove(0); - // { - // start.push((Path::Unbox(Box::new(path.clone())), arg)); - // start.extend(end); - // } - // } else { + } => { + debug_assert!(test_name == &TagName::Global("#Record".into())); let sub_positions = - arguments - .into_iter() - .enumerate() - .map(|(index, (pattern, _))| { - ( - Path::Index { - index: index as u64, - tag_id: *tag_id, - path: Box::new(path.clone()), - }, - pattern, - ) - }); + destructs.into_iter().enumerate().map(|(index, destruct)| { + let pattern = if let Some(guard) = destruct.guard { + guard.clone() + } else { + Pattern::Underscore + }; + + ( + Path::Index { + index: index as u64, + tag_id: *tag_id, + path: Box::new(path.clone()), + }, + pattern, + ) + }); start.extend(sub_positions); start.extend(end); @@ -419,66 +423,111 @@ fn to_relevant_branch<'a>(test: &Test<'a>, path: &Path, branch: Branch<'a>) -> O }) } _ => None, + }, + + AppliedTag { + tag_name, + arguments, + union, + .. + } => { + match test { + IsCtor { + tag_name: test_name, + tag_id, + .. + } if &tag_name == test_name => { + // Theory: Unbox doesn't have any value for us + if arguments.len() == 1 && union.alternatives.len() == 1 { + let arg = arguments[0].clone(); + { + start.push((Path::Unbox(Box::new(path.clone())), arg.0)); + start.extend(end); + } + } else { + let sub_positions = arguments.into_iter().enumerate().map( + |(index, (pattern, _))| { + ( + Path::Index { + index: index as u64, + tag_id: *tag_id, + path: Box::new(path.clone()), + }, + pattern, + ) + }, + ); + start.extend(sub_positions); + start.extend(end); + } + + Some(Branch { + goal: branch.goal, + patterns: start, + }) + } + _ => None, + } } + StrLiteral(string) => match test { + IsStr(test_str) if string == *test_str => { + start.extend(end); + Some(Branch { + goal: branch.goal, + patterns: start, + }) + } + _ => None, + }, + + IntLiteral(int) => match test { + IsInt(is_int) if int == *is_int => { + start.extend(end); + Some(Branch { + goal: branch.goal, + patterns: start, + }) + } + _ => None, + }, + + FloatLiteral(float) => match test { + IsFloat(test_float) if float == *test_float => { + start.extend(end); + Some(Branch { + goal: branch.goal, + patterns: start, + }) + } + _ => None, + }, + + BitLiteral(bit) => match test { + IsBit(test_bit) if bit == *test_bit => { + start.extend(end); + Some(Branch { + goal: branch.goal, + patterns: start, + }) + } + _ => None, + }, + + EnumLiteral { tag_id, .. } => match test { + IsByte { + tag_id: test_id, .. + } if tag_id == *test_id => { + start.extend(end); + Some(Branch { + goal: branch.goal, + patterns: start, + }) + } + + _ => None, + }, } - StrLiteral(string) => match test { - IsStr(test_str) if string == *test_str => { - start.extend(end); - Some(Branch { - goal: branch.goal, - patterns: start, - }) - } - _ => None, - }, - - IntLiteral(int) => match test { - IsInt(is_int) if int == *is_int => { - start.extend(end); - Some(Branch { - goal: branch.goal, - patterns: start, - }) - } - _ => None, - }, - - FloatLiteral(float) => match test { - IsFloat(test_float) if float == *test_float => { - start.extend(end); - Some(Branch { - goal: branch.goal, - patterns: start, - }) - } - _ => None, - }, - - BitLiteral(bit) => match test { - IsBit(test_bit) if bit == *test_bit => { - start.extend(end); - Some(Branch { - goal: branch.goal, - patterns: start, - }) - } - _ => None, - }, - - EnumLiteral { tag_id, .. } => match test { - IsByte { - tag_id: test_id, .. - } if tag_id == *test_id => { - start.extend(end); - Some(Branch { - goal: branch.goal, - patterns: start, - }) - } - - _ => None, - }, - }, + } } }