diff --git a/compiler/constrain/src/expr.rs b/compiler/constrain/src/expr.rs index 77b8a3c82c..5584779f03 100644 --- a/compiler/constrain/src/expr.rs +++ b/compiler/constrain/src/expr.rs @@ -9,7 +9,7 @@ use roc_can::expected::PExpected; use roc_can::expr::Expr::{self, *}; use roc_can::expr::{Field, WhenBranch}; use roc_can::pattern::Pattern; -use roc_collections::all::{ImMap, Index, SendMap}; +use roc_collections::all::{ImMap, Index, MutSet, SendMap}; use roc_module::ident::{Lowercase, TagName}; use roc_module::symbol::{ModuleId, Symbol}; use roc_region::all::{Located, Region}; @@ -1438,13 +1438,15 @@ fn instantiate_rigids( annotation: &Type, introduced_vars: &IntroducedVariables, new_rigids: &mut Vec, - ftv: &mut ImMap, + ftv: &mut ImMap, // rigids defined before the current annotation loc_pattern: &Located, headers: &mut SendMap>, ) -> Type { let mut annotation = annotation.clone(); let mut rigid_substitution: ImMap = ImMap::default(); + let outside_rigids: MutSet = ftv.values().copied().collect(); + for (name, var) in introduced_vars.var_by_name.iter() { if let Some(existing_rigid) = ftv.get(name) { rigid_substitution.insert(*var, Type::Variable(*existing_rigid)); @@ -1464,7 +1466,12 @@ fn instantiate_rigids( &Located::at(loc_pattern.region, annotation.clone()), ) { for (symbol, loc_type) in new_headers { - new_rigids.extend(loc_type.value.variables()); + for var in loc_type.value.variables() { + // a rigid is only new if this annotation is the first occurence of this rigid + if !outside_rigids.contains(&var) { + new_rigids.push(var); + } + } headers.insert(symbol, loc_type); } } diff --git a/compiler/test_gen/src/gen_primitives.rs b/compiler/test_gen/src/gen_primitives.rs index d183b95d0b..1de82fa5f7 100644 --- a/compiler/test_gen/src/gen_primitives.rs +++ b/compiler/test_gen/src/gen_primitives.rs @@ -2906,3 +2906,77 @@ fn do_pass_bool_byte_closure_layout() { RocStr ); } + +#[test] +fn nested_rigid_list() { + assert_evals_to!( + indoc!( + r#" + app "test" provides [ main ] to "./platform" + + foo : List a -> List a + foo = \list -> + p2 : List a + p2 = list + + p2 + + main = + when foo [] is + _ -> "hello world" + "# + ), + RocStr::from_slice(b"hello world"), + RocStr + ); +} + +#[test] +fn nested_rigid_alias() { + assert_evals_to!( + indoc!( + r#" + app "test" provides [ main ] to "./platform" + + Identity a : [ @Identity a ] + + foo : Identity a -> Identity a + foo = \list -> + p2 : Identity a + p2 = list + + p2 + + main = + when foo (@Identity "foo") is + _ -> "hello world" + "# + ), + RocStr::from_slice(b"hello world"), + RocStr + ); +} + +#[test] +fn nested_rigid_tag_union() { + assert_evals_to!( + indoc!( + r#" + app "test" provides [ main ] to "./platform" + + foo : [ @Identity a ] -> [ @Identity a ] + foo = \list -> + p2 : [ @Identity a ] + p2 = list + + p2 + + main = + when foo (@Identity "foo") is + _ -> "hello world" + "# + ), + RocStr::from_slice(b"hello world"), + RocStr + ); +}