diff --git a/compiler/mono/src/reset_reuse.rs b/compiler/mono/src/reset_reuse.rs index a19b32113a..bb5d73c168 100644 --- a/compiler/mono/src/reset_reuse.rs +++ b/compiler/mono/src/reset_reuse.rs @@ -114,23 +114,21 @@ fn function_s<'a, 'i>( let body: &Stmt = *body; let new_body = function_s(env, w, c, body); - let new_join = if std::ptr::eq(body, new_body) || body == new_body { - // the join point body will consume w - Join { - id, - parameters, - body: new_body, - remainder, - } - } else { + let new_join = if body == new_body { let new_remainder = function_s(env, w, c, remainder); - Join { id, parameters, body, remainder: new_remainder, } + } else { + Join { + id, + parameters, + body: new_body, + remainder, + } }; arena.alloc(new_join) @@ -209,7 +207,7 @@ fn try_function_s<'a, 'i>( let new_stmt = function_s(env, w, c, stmt); - if std::ptr::eq(stmt, new_stmt) || stmt == new_stmt { + if stmt == new_stmt { stmt } else { insert_reset(env, w, x, Layout::Union(c.layout), new_stmt) @@ -298,11 +296,16 @@ fn function_d_main<'a, 'i>( _ => { let (b, found) = function_d_main(env, x, c, continuation); + // NOTE the &b != continuation is not found in the Lean source, but is required + // otherwise we observe the same symbol being reset twice let mut result = MutSet::default(); - if found || { - occurring_variables_expr(expr, &mut result); - !result.contains(&x) - } { + if found + || { + occurring_variables_expr(expr, &mut result); + !result.contains(&x) + } + || &b != continuation + { let let_stmt = Let(*symbol, expr.clone(), *layout, b); (arena.alloc(let_stmt), found) @@ -413,10 +416,10 @@ fn function_d_main<'a, 'i>( let (b, found) = function_d_main(env, x, c, remainder); - let (v, _found) = function_d_main(env, x, c, body); - env.jp_live_vars.remove(id); + let (v, _found) = function_d_main(env, x, c, body); + // If `found' == true`, then `Dmain b` must also have returned `(b, true)` since // we assume the IR does not have dead join points. So, if `x` is live in `j` (i.e., `v`), // then it must also live in `b` since `j` is reachable from `b` with a `jmp`. @@ -463,10 +466,6 @@ fn function_r<'a, 'i>(env: &mut Env<'a, 'i>, stmt: &'a Stmt<'a>) -> &'a Stmt<'a> let mut new_branches = Vec::with_capacity_in(branches.len(), arena); // TODO for non-recursive unions there is no benefit - let benefits_from_reuse = match cond_layout { - Layout::Union(union_layout) => Some(union_layout), - _ => None, - }; for (tag, info, body) in branches.iter() { let temp = function_r(env, body); @@ -478,6 +477,7 @@ fn function_r<'a, 'i>(env: &mut Env<'a, 'i>, stmt: &'a Stmt<'a>) -> &'a Stmt<'a> layout, tag_id, } => match layout { + Layout::Union(UnionLayout::NonRecursive(_)) => temp, Layout::Union(union_layout) if !union_layout.tag_is_null(*tag_id) => { let ctor_info = CtorInfo { layout: *union_layout, @@ -503,6 +503,7 @@ fn function_r<'a, 'i>(env: &mut Env<'a, 'i>, stmt: &'a Stmt<'a>) -> &'a Stmt<'a> layout, tag_id, } => match layout { + Layout::Union(UnionLayout::NonRecursive(_)) => temp, Layout::Union(union_layout) if !union_layout.tag_is_null(*tag_id) => { let ctor_info = CtorInfo { layout: *union_layout, @@ -536,14 +537,13 @@ fn function_r<'a, 'i>(env: &mut Env<'a, 'i>, stmt: &'a Stmt<'a>) -> &'a Stmt<'a> } => { env.jp_live_vars.insert(*id, LiveVarSet::default()); - let body_live_vars = collect_stmt(body, &env.jp_live_vars, LiveVarSet::default()); + let v = function_r(env, body); + let body_live_vars = collect_stmt(v, &env.jp_live_vars, LiveVarSet::default()); env.jp_live_vars.insert(*id, body_live_vars); let b = function_r(env, remainder); - let v = function_r(env, body); - env.jp_live_vars.remove(id); let join = Join {