diff --git a/crates/compiler/mono/src/drop_specialization.rs b/crates/compiler/mono/src/drop_specialization.rs index af0312b491..761a8db60d 100644 --- a/crates/compiler/mono/src/drop_specialization.rs +++ b/crates/compiler/mono/src/drop_specialization.rs @@ -677,7 +677,7 @@ fn specialize_drops_stmt<'a, 'i>( let mut body_jump_incremented_symbols = remainder_jump_incremented_symbols; // Perform iteration to get the incremented_symbols for the body. - let joinpoint_usage = loop { + let joinpoint_info = loop { // Update the incremented_symbols to the remainder's incremented_symbols. let mut current_body_environment = body_environment.clone(); current_body_environment.incremented_symbols = @@ -701,18 +701,21 @@ fn specialize_drops_stmt<'a, 'i>( .clone(); if body_jump_incremented_symbols == new_body_jump_incremented_symbols { - break new_body_jump_incremented_symbols; + break ( + new_body_jump_incremented_symbols, + current_body_environment.incremented_symbols, + ); } else { body_jump_incremented_symbols = new_body_jump_incremented_symbols; } }; - let join_joinpoint_usage = arena.alloc(joinpoint_usage.clone()); + let alloced_joinpoint_info = arena.alloc(joinpoint_info.clone()); - body_environment.incremented_symbols = joinpoint_usage; + body_environment.incremented_symbols = joinpoint_info.0; body_environment .join_incremented_symbols - .insert(*id, join_joinpoint_usage); + .insert(*id, alloced_joinpoint_info); let new_body = specialize_drops_stmt( arena, @@ -724,7 +727,7 @@ fn specialize_drops_stmt<'a, 'i>( environment .join_incremented_symbols - .insert(*id, join_joinpoint_usage); + .insert(*id, alloced_joinpoint_info); arena.alloc(Stmt::Join { id: *id, @@ -741,7 +744,7 @@ fn specialize_drops_stmt<'a, 'i>( } Stmt::Jump(joinpoint_id, arguments) => { match environment.join_incremented_symbols.get(joinpoint_id) { - Some(join_usage) => { + Some((join_usage, join_returns)) => { // Consume all symbols that were consumed in the join. for (symbol, count) in join_usage.map.iter() { for _ in 0..*count { @@ -752,6 +755,11 @@ fn specialize_drops_stmt<'a, 'i>( ); } } + for (symbol, count) in join_returns.map.iter() { + environment + .incremented_symbols + .insert_count(*symbol, *count); + } } None => { // No join usage, let the join know the minimum amount of symbols that were incremented from each jump. @@ -1411,7 +1419,7 @@ struct DropSpecializationEnvironment<'a> { jump_incremented_symbols: MutMap>, // A map containing the expected number of symbol increments from joinpoints for a jump. - join_incremented_symbols: MutMap>, + join_incremented_symbols: MutMap, CountingMap)>, } impl<'a> DropSpecializationEnvironment<'a> { diff --git a/crates/compiler/test_mono/generated/drop_specialize_inc_after_jump.txt b/crates/compiler/test_mono/generated/drop_specialize_inc_after_jump.txt new file mode 100644 index 0000000000..dd31b02248 --- /dev/null +++ b/crates/compiler/test_mono/generated/drop_specialize_inc_after_jump.txt @@ -0,0 +1,27 @@ +procedure Bool.2 (): + let Bool.23 : Int1 = true; + ret Bool.23; + +procedure Test.2 (Test.5): + let Test.6 : Int1 = CallByName Bool.2; + let Test.7 : {Str, Str} = StructAtIndex 0 Test.5; + inc 2 Test.7; + joinpoint Test.13 Test.8: + let Test.9 : {{Str, Str}, {{Str, Str}, Str}} = Struct {Test.7, Test.5}; + let Test.11 : {{Str, Str}, {{Str, Str}, {{Str, Str}, Str}}} = Struct {Test.7, Test.9}; + ret Test.11; + in + if Test.6 then + let Test.12 : I64 = 1i64; + jump Test.13 Test.12; + else + let Test.12 : I64 = 0i64; + jump Test.13 Test.12; + +procedure Test.0 (): + let Test.3 : Str = "value"; + inc 2 Test.3; + let Test.14 : {Str, Str} = Struct {Test.3, Test.3}; + let Test.4 : {{Str, Str}, Str} = Struct {Test.14, Test.3}; + let Test.10 : {{Str, Str}, {{Str, Str}, {{Str, Str}, Str}}} = CallByName Test.2 Test.4; + ret Test.10; diff --git a/crates/compiler/test_mono/src/tests.rs b/crates/compiler/test_mono/src/tests.rs index 8ba4b4a0b0..81a1becc4d 100644 --- a/crates/compiler/test_mono/src/tests.rs +++ b/crates/compiler/test_mono/src/tests.rs @@ -3111,3 +3111,26 @@ fn dbg_in_expect() { "### ) } + +#[mono_test] +fn drop_specialize_inc_after_jump() { + indoc!( + r#" + app "test" provides [main] to "./platform" + + Tuple a b : { left : a, right : b } + + main = + v = "value" + t = { left: { left: v, right: v }, right: v } + tupleItem t + + tupleItem = \t -> + true = Bool.true + l = t.left + x = if true then 1 else 0 + t2 = {left: l, right: t} + {left: l, right: t2} + "# + ) +}